#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/ScalarOps.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/Utils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/Activation.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/OptionalArrayRef.h>
#include <c10/util/SmallBuffer.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <ciso646>
#include <functional>
#include <numeric>
#include <utility>

// Helper functions for autogenerated code
// These used to be inlined into the codegened Functions.cpp

namespace torch {
namespace autograd {
namespace generated {
namespace details {

using at::areAnyTensorSubclassLike;
using at::IntArrayRef;
using at::OptionalIntArrayRef;
using at::Scalar;
using at::Tensor;
using at::TensorList;

// Error text used when double backward is attempted through a cuDNN RNN. It
// explains the limitation and shows how to temporarily disable the cuDNN
// backend for the forward pass (presumably referenced from the generated
// RNN backward code — confirm at the use sites).
const char* kCudnnDoubleBackwardMsg =
    "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: \nwith torch.backends.cudnn.flags(enabled=False):\n    output = model(inputs)";

// Applies a loss reduction mode to an elementwise loss tensor:
//   at::Reduction::Mean -> mean over all elements,
//   at::Reduction::Sum  -> sum over all elements,
//   any other value     -> the tensor is returned unreduced.
Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction) {
  switch (reduction) {
    case at::Reduction::Mean:
      return unreduced.mean();
    case at::Reduction::Sum:
      return unreduced.sum();
    default:
      return unreduced;
  }
}

// Returns true only if the optional is engaged AND holds a defined tensor.
static bool isDefined(const std::optional<Tensor>& t) {
  if (!t.has_value()) {
    return false;
  }
  return t->defined();
}

// Unwraps an optional tensor, substituting an undefined Tensor when empty.
Tensor toNonOptTensor(const std::optional<Tensor>& t) {
  return t.value_or(Tensor());
}

// Returns the level-0 forward-mode gradient of an optional tensor, or an
// undefined Tensor when the optional is empty or holds an undefined tensor.
Tensor toNonOptFwGrad(const std::optional<Tensor>& t) {
  if (!t.has_value() || !t->defined()) {
    return Tensor();
  }
  return t->_fw_grad(/*level=*/0);
}

// Returns the level-0 forward-mode primal of an optional tensor.
// Wrapped numbers are returned as-is; an empty optional or undefined tensor
// yields an undefined Tensor.
Tensor toNonOptPrimal(const std::optional<Tensor>& t) {
  if (!t.has_value() || !t->defined()) {
    return Tensor();
  }
  const bool is_wrapped = t->unsafeGetTensorImpl()->is_wrapped_number();
  if (is_wrapped) {
    return *t;
  }
  return t->_fw_primal(/*level=*/0);
}

// Stores a single tensor into `out` at the one-slot range `range`.
// Throws if the range exceeds `out` or does not span exactly one slot.
void copy_range(variable_list& out, IndexRange range, const Tensor& t) {
  // The condition text of this message-less check becomes part of the error,
  // so it is kept in its original form.
  TORCH_CHECK(range.second <= out.size());
  const auto span = range.second - range.first;
  TORCH_CHECK(span == 1, "inconsistent range for Tensor output");
  const auto slot = range.first;
  out[slot] = t;
}

// Stores a list of tensors into consecutive slots of `out` starting at
// range.first. Throws if the range exceeds `out` or its length differs from
// the number of tensors.
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
  // Condition text is stringified into the error message; keep it verbatim.
  TORCH_CHECK(range.second <= out.size());
  const auto span = range.second - range.first;
  TORCH_CHECK(span == t.size(), "inconsistent range for TensorList output");
  auto dest = out.begin() + static_cast<int64_t>(range.first);
  for (const Tensor& tensor : t) {
    *dest = tensor;
    ++dest;
  }
}

// Gradient of copysign w.r.t. `self`: grad * (result / self). Where
// self == 0 the ratio is 0/0, so it is set to 0 there.
// (Assumes `result` is the forward copysign output.)
Tensor copysign_tensor_self_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& result) {
  Tensor sign_ratio = result / self;
  const Tensor self_is_zero = self == 0;
  sign_ratio.masked_fill_(self_is_zero, 0);
  return grad * sign_ratio;
}

// Shared implementation of the not_implemented* helpers. It always throws
// via TORCH_CHECK_NOT_IMPLEMENTED. The message names the operator `name` and
// appends `reason` when it is non-empty. T only fixes the declared return
// type; nothing is ever returned.
template <typename T>
T not_implemented_base(const char* name, const char* reason) {
  std::string msg =
      c10::str("the derivative for '", name, "' is not implemented.");
  const bool has_reason = reason[0] != '\0';
  if (has_reason) {
    msg = c10::str(msg, " ", reason);
  }
  TORCH_CHECK_NOT_IMPLEMENTED(false, msg);
}

// Always throws: the derivative of `name` is not implemented (Tensor result).
Tensor not_implemented(const char* name, const char* reason) {
  return not_implemented_base<at::Tensor>(name, reason);
}

std::vector<Tensor> not_implemented_list(const char* name, const char* reason) {
  return not_implemented_base<std::vector<Tensor>>(name, reason);
}

// Returns t * s, skipping the multiplication when s is a floating-point or
// integral (bool included) scalar equal to 1. Any other scalar kind (e.g.
// complex) always multiplies.
Tensor maybe_multiply(const Tensor& t, const Scalar& s) {
  const bool is_one = s.isFloatingPoint()
      ? (s.toSymFloat() == 1)
      : (s.isIntegral(true) && s.toSymInt() == 1);
  if (!is_one) {
    return t * s;
  }
  return t;
}

// Product of `sizes` over the (possibly negative) dims in `dim`.
// Returns 1 when `sizes` is empty (0-dim tensor) or `dim` is empty.
int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
  if (sizes.empty()) {
    return 1;
  }
  const auto rank = static_cast<int64_t>(sizes.size());
  int64_t product = 1;
  for (const int64_t d : dim) {
    product *= sizes[at::maybe_wrap_dim(d, rank)];
  }
  return product;
}

// SymInt variant of _safe_size: product of symbolic `sizes` over the
// (possibly negative) dims in `dim`; 1 when `sizes` or `dim` is empty.
static c10::SymInt _safe_size(c10::SymIntArrayRef sizes, c10::IntArrayRef dim) {
  if (sizes.empty()) {
    return 1;
  }
  const auto rank = static_cast<int64_t>(sizes.size());
  c10::SymInt product = 1;
  for (const int64_t d : dim) {
    product *= sizes[at::maybe_wrap_dim(d, rank)];
  }
  return product;
}

// If the input dtype `self_st` is real but the computed gradient is complex
// (an R -> C function), keep only the real part of the gradient.
// Otherwise the gradient is returned unchanged.
Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) {
  if (at::isComplexType(self_st) || !gradient_result.is_complex()) {
    return gradient_result;
  }
  return at::real(gradient_result);
}

// Tensor overload: if `self` is real but the gradient is complex (R -> C),
// keep only the real part of the gradient; otherwise return it unchanged.
static Tensor handle_r_to_c(const Tensor& self, Tensor gradient_result) {
  if (self.is_complex() || !gradient_result.is_complex()) {
    return gradient_result;
  }
  return at::real(gradient_result);
}

// Re-inserts size-1 dims at positions `dims` into `output`, undoing a
// keepdim=false reduction. It is a no-op when keepdim is true.
// Negative entries in `dims` are relative to the restored rank.
Tensor restore_reduced_dims(
    const Tensor& output,
    IntArrayRef dims,
    bool keepdim) {
  if (keepdim) {
    return output;
  }
  const auto restored_rank = output.dim() + dims.size();
  // Placeholder shape: 1 marks a reduced dim, 0 marks a slot that still has
  // to be filled from output's sizes.
  std::vector<c10::SymInt> target_shape(restored_rank, 0);
  for (const int64_t d : dims) {
    const int64_t pos = d < 0 ? static_cast<int64_t>(restored_rank) + d : d;
    target_shape[pos] = 1;
  }
  size_t slot = 0;
  for (const c10::SymInt& extent : output.sym_sizes()) {
    // Skip over the reduced-dim markers to the next free slot.
    while (target_shape[slot] > 0) {
      ++slot;
    }
    target_shape[slot] = extent;
    ++slot;
  }
  return output.reshape_symint(target_shape);
}

// Spreads `grad` evenly over the positions selected by `mask`: each selected
// entry receives grad divided by the number of selected entries in its slice
// along `dims`. Unselected entries receive zero.
Tensor scale_grad_by_count(
    const Tensor& grad,
    const Tensor& mask,
    IntArrayRef dims) {
  const Tensor selected_count = mask.sum(dims, /*keepdim=*/true);
  return (grad / selected_count) * mask;
}

// Forward-mode derivative of amax/amin: the mean of the tangent `dx` over
// the positions where `x` attains the reduced value `result`, taken per
// reduced slice.
Tensor amaxamin_jvp(
    const Tensor& x,
    const Tensor& dx,
    const Tensor& result,
    IntArrayRef dim,
    bool keepdim) {
  const Tensor hits = x == restore_reduced_dims(result, dim, keepdim);
  const Tensor tangent_total = at::where(hits, dx, 0.).sum(dim, keepdim);
  const Tensor hit_count = hits.sum(dim, keepdim);
  return tangent_total / hit_count;
}

std::tuple<Tensor, Tensor> _euclidean_dist_backward(
    const Tensor& grad,
    const Tensor& x1,
    const Tensor& x2,
    const Tensor& res) {
  if (!grad.defined()) {
    return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
  }
  // handle case at 0 where we return a subgradient containing 0
  Tensor ratio = grad / res;
  ratio.masked_fill_(res == 0, 0);
  return std::tuple<Tensor, Tensor>{
      x1 * ratio.sum(-1, true) - ratio.matmul(x2),
      x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.mT().matmul(x1)};
}

// Convenience overload of norm_backward. It forwards to the dim-aware version
// with an empty dim list and keepdim=true, so grad and norm are not
// unsqueezed.
Tensor norm_backward(
    const Tensor& grad,
    const Tensor& self,
    const std::optional<Scalar>& p_,
    const Tensor& norm) {
  const IntArrayRef no_dims{};
  return norm_backward(grad, self, p_, norm, no_dims, /*keepdim=*/true);
}

// Backward of a p-norm reduction of `self` over `dim`.
//
// grad:    gradient w.r.t. the norm output.
// self:    forward input.
// p_:      order of the norm; 2 when not provided.
// norm:    forward output.
// dim, keepdim: reduction arguments of the forward call. When keepdim is
//          false (and self is not 0-dim), grad and norm are unsqueezed back
//          along `dim` so they broadcast against `self`.
// Returns the gradient w.r.t. `self`, or an undefined Tensor when p == 0.
Tensor norm_backward(
    Tensor grad,
    const Tensor& self,
    const std::optional<Scalar>& p_,
    Tensor norm,
    IntArrayRef dim,
    bool keepdim) {
  // NB: Entries where norm == 0 are masked to zero AFTER a float division by
  //     zero, which ASAN complains about. Filling the problematic values
  //     before the division would appease ASAN, but is avoided for
  //     performance; ASAN is meant to be silenced where necessary instead.
  //     NOTE(review): no sanitizer suppression is visible in this function
  //     itself — confirm where it is applied.
  size_t ndim = self.dim();
  double p = p_.value_or(2.0).toDouble();
  Tensor self_scaled;
  Tensor scale_v;

  // Restore the reduced dims so grad/norm broadcast against self.
  if (!keepdim && self.dim() != 0) {
    grad = unsqueeze_multiple(grad, dim, ndim);
    norm = unsqueeze_multiple(norm, dim, ndim);
  }

  if (p == 0.0) {
    // An undefined Tensor is returned here (presumably treated as a zero
    // gradient by the caller, since the p = 0 "norm" is piecewise constant —
    // confirm).
    return {};
  } else if (p == 1.0) {
    // d|x|_1 / dx = sgn(x).
    return self.sgn() * grad;
  } else if (p == 2.0) {
    // d|x|_2 / dx = x / |x|_2, defined as 0 where the norm is 0.
    return grad * (self / norm).masked_fill_(norm == 0, 0);
  } else if (std::isinf(p)) {
    // Derivative of amax(abs(self), dim, keepdim) but respecting nans
    // We create a mask of `argmax`: it's argmax if self.abs() == norm or it's
    // NaN. The gradient is split evenly among all such positions.
    auto self_abs = self.abs();
    auto mask = self_abs.eq(norm).logical_or(self_abs.isnan());
    return self.sgn() * ((grad / mask.sum(dim, true)) * mask);
  } else if (p < 1.0) {
    // General formula sgn(x) * |x|^(p-1) * |x|_p^(1-p). Here p - 1 < 0, so
    // |x|^(p-1) blows up at x == 0 and is masked to 0 there.
    self_scaled =
        self.sgn() * self.abs().pow_(p - 1).masked_fill_(self == 0, 0);
    return self_scaled * grad * norm.pow(1 - p);
  } else if (p < 2.0) {
    // Same formula written as sgn(x)|x|^(p-1) / |x|_p^(p-1); the division is
    // masked where the norm is 0.
    self_scaled = self.sgn() * self.abs().pow_(p - 1);
    scale_v = grad / norm.pow(p - 1);
    scale_v.masked_fill_(norm == 0, 0);
    return self_scaled * scale_v;
  } else {
    // For p >= 2, x * |x|^(p-2) equals sgn(x)|x|^(p-1) for x != 0 and
    // needs no sgn() call since p - 2 >= 0.
    self_scaled = self * self.abs().pow_(p - 2);
    scale_v = grad / norm.pow(p - 1);
    scale_v.masked_fill_(norm == 0, 0);
    return self_scaled * scale_v;
  }
}

// Forward-mode derivative (JVP) of a p-norm reduction over `dim`.
// Like norm_backward, this divides by `norm` and masks zero-norm entries
// afterwards rather than guarding before the division, which sanitizers may
// flag as division by zero.
//
// self_p:  primal input; self_t: its tangent.
// p_:      order of the norm; 2 when not provided.
// norm:    forward output.
// dim, keepdim: reduction arguments of the forward call.
// Returns the tangent of the norm output.
Tensor norm_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    const std::optional<Scalar>& p_,
    Tensor norm,
    IntArrayRef dim,
    bool keepdim) {
  // NB: currently norm_jvp is also reused for dist's jvp (which has two
  // differentiable inputs), but self_t still cannot be a ZT (zero tensor)
  // because that would require both self_t and other_t to be ZT.
  TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor());
  size_t ndim = self_p.dim(); // composite compliance?
  double p = p_.value_or(2.0).toDouble();

  if (p == 0.0) {
    // Zero tangent for p == 0.
    return at::zeros_like(norm);
  } else if (p == 1.0) {
    // sum(Re(sgn(x) * conj(dx))); in-place mul is avoided for tensor
    // subclasses.
    auto result = self_p.sgn();
    result = areAnyTensorSubclassLike({self_t}) ? result.mul(self_t.conj())
                                                : result.mul_(self_t.conj());
    result = at::real(result);
    return result.sum(dim, keepdim);
  } else if (p == 2.0) {
    // sum(Re(x * conj(dx))) / |x|_2, set to 0 where the norm is 0.
    auto result = self_p.mul(self_t.conj());
    result = at::real(result);
    result = result.sum(dim, keepdim);
    return result.div_(norm).masked_fill_(norm == 0, 0);
  } else if (std::isinf(p)) {
    // Average Re(sgn(x) * conj(dx)) over the positions that attain the max
    // (|x| == norm, or both x and norm are NaN).
    if (!keepdim && self_p.dim() != 0) {
      norm = unsqueeze_multiple(norm, dim, ndim);
    }
    const auto self_isnan = self_p.isnan();
    const auto norm_isnan = norm.isnan();
    const auto& self_and_norm_isnan = areAnyTensorSubclassLike({norm})
        ? self_isnan.logical_and(norm_isnan)
        : self_isnan.logical_and_(norm_isnan);
    const auto is_eq_max =
        (self_p.abs() == norm).logical_or_(self_and_norm_isnan).type_as(norm);
    // count_nonzero drops the reduced dims, so re-insert them for
    // broadcasting.
    auto nb_max = is_eq_max.count_nonzero(dim);
    if (self_p.dim() != 0) {
      nb_max = unsqueeze_multiple(nb_max, dim, ndim);
    }
    return (at::real(self_p.sgn() * self_t.conj()) * is_eq_max / nb_max)
        .sum(dim, keepdim);
  } else if (p < 1.0) {
    // |x|^(p-1) blows up at x == 0 (p - 1 < 0), so it is masked to 0 there.
    auto sumpow_t = (self_p.abs().pow_(p - 1).masked_fill_(self_p == 0, 0) *
                     at::real(self_p.sgn() * self_t.conj()))
                        .sum(dim, keepdim);
    return sumpow_t * norm.pow(1 - p);
  } else if (p < 2.0) {
    auto sumpow_t =
        (self_p.abs().pow_(p - 1) * at::real(self_p.sgn() * self_t.conj()))
            .sum(dim, keepdim);
    auto out = sumpow_t / norm.pow(p - 1);
    return out.masked_fill_(norm == 0, 0);
  } else {
    // For p >= 2, |x|^(p-2) * x replaces |x|^(p-1) * sgn(x).
    auto sumpow_t =
        (self_p.abs().pow_(p - 2) * at::real(self_p * self_t.conj()))
            .sum(dim, keepdim);
    auto out = sumpow_t / norm.pow(p - 1);
    return out.masked_fill_(norm == 0, 0);
  }
}

// Convenience overload of norm_jvp: empty dim list, keepdim=true.
Tensor norm_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    const std::optional<Scalar>& p_,
    Tensor norm) {
  const IntArrayRef no_dims{};
  return norm_jvp(
      self_p, self_t, p_, std::move(norm), no_dims, /*keepdim=*/true);
}

// Backward of nested_from_padded: pads the nested gradient back to a dense
// tensor shaped like `input`.
// With do_transform_0213, the gradient is padded to (d0, d2, d1 * d3). It is
// then viewed as (d0, d2, d1, d3) and permuted back to (d0, d1, d2, d3).
Tensor _nested_from_padded_backward(
    const Tensor& grad,
    const Tensor& input,
    bool do_transform_0213) {
  if (!do_transform_0213) {
    return grad.to_padded_tensor(0, input.sizes());
  }
  const int64_t d0 = input.size(0);
  const int64_t d1 = input.size(1);
  const int64_t d2 = input.size(2);
  const int64_t d3 = input.size(3);
  const auto padded_sizes = {d0, d2, d1 * d3};
  const auto split_sizes = {d0, d2, d1, d3};
  Tensor padded = grad.to_padded_tensor(0, padded_sizes);
  return padded.view(split_sizes).permute({0, 2, 1, 3});
}

// Double backward of linear. Given the incoming gradients `grads` of the
// first backward's outputs, it returns (grad_self, grad_grad_output,
// grad_weight).
// From the code: grads[1] combines with grad_output to give grad_self;
// grads[0] combines with grad_output to give grad_weight; grads[0], grads[1]
// and grads[2] all contribute to grad_grad_output. Presumably grads is
// (ggI, ggW, ggb) of the first backward — confirm against derivatives.yaml.
// All results are undefined if grad_output is undefined.
std::tuple<Tensor, Tensor, Tensor> linear_double_backward(
    const variable_list& grads,
    const Tensor& self,
    const Tensor& grad_output,
    const Tensor& weight) {
  if (!grad_output.defined()) {
    return std::make_tuple(Tensor(), Tensor(), Tensor());
  }

  Tensor grad_self, grad_grad_output, grad_weight;

  if (grads[1].defined()) {
    // A 1-D grad_output is treated as a single row, then squeezed back.
    grad_self =
        (grad_output.dim() == 1 ? grad_output.unsqueeze(0) : grad_output)
            .matmul(grads[1]);
    if (grad_output.dim() == 1) {
      grad_self = grad_self.squeeze(0);
    }
  }
  if (grads[0].defined()) {
    grad_weight =
        (grad_output.dim() == 1 ? grad_output.unsqueeze(1) : grad_output.mT())
            .matmul(grads[0].dim() == 1 ? grads[0].unsqueeze(0) : grads[0]);
  }

  // grad_grad_output accumulates contributions from every defined grads[i],
  // starting from zeros (kept 2-D while accumulating).
  if (grads[0].defined() || grads[1].defined() || grads[2].defined()) {
    grad_grad_output = at::zeros_like(grad_output);
    if (grad_output.dim() == 1) {
      grad_grad_output = grad_grad_output.unsqueeze(0);
    }
  }

  if (grads[0].defined()) {
    grad_grad_output = grad_grad_output +
        (grads[0].dim() == 1 ? grads[0].unsqueeze(0) : grads[0])
            .matmul(weight.mT());
  }
  if (grads[1].defined()) {
    grad_grad_output = grad_grad_output +
        (self.dim() == 1 ? self.unsqueeze(0) : self).matmul(grads[1].mT());
  }
  if (grads[2].defined()) {
    grad_grad_output = grad_grad_output + grads[2];
  }
  // Undo the temporary unsqueeze for 1-D grad_output.
  if (grad_grad_output.defined() && grad_output.dim() == 1) {
    grad_grad_output = grad_grad_output.squeeze(0);
  }

  return std::make_tuple(
      std::move(grad_self),
      std::move(grad_grad_output),
      std::move(grad_weight));
}

// JVP of linalg.vector_norm. A missing `opt_dim` becomes an empty dim list.
// The op's dtype argument needs no handling here: it is handled via
// broadcasting in the function.
Tensor linalg_vector_norm_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    const Scalar& scalar_ord,
    Tensor norm,
    const at::OptionalIntArrayRef& opt_dim,
    bool keepdim) {
  const IntArrayRef dim =
      opt_dim.has_value() ? opt_dim.value() : IntArrayRef();
  return norm_jvp(self_p, self_t, scalar_ord, std::move(norm), dim, keepdim);
}

// Backward of linalg.vector_norm. A missing `opt_dim` becomes an empty dim
// list. The op's dtype argument needs no handling here: it is handled via
// broadcasting in the function.
Tensor linalg_vector_norm_backward(
    Tensor grad,
    const Tensor& self,
    const Scalar& scalar_ord,
    Tensor norm,
    const at::OptionalIntArrayRef& opt_dim,
    bool keepdim) {
  const IntArrayRef dim =
      opt_dim.has_value() ? opt_dim.value() : IntArrayRef();
  return norm_backward(
      std::move(grad), self, scalar_ord, std::move(norm), dim, keepdim);
}

// Gradient of self ** exponent (scalar exponent) w.r.t. self:
// grad * conj(exponent * self^(exponent - 1)). For a zero exponent the
// gradient is all zeros. A complex gradient for a real `self` is reduced to
// its real part.
Tensor pow_backward(Tensor grad, const Tensor& self, const Scalar& exponent) {
  if (exponent.equal(0.0)) {
    return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  auto chain_rule = [&grad, &self](auto e) {
    return grad * (e * self.pow(e - 1)).conj();
  };
  Tensor out;
  if (exponent.isComplex()) {
    out = chain_rule(exponent.toComplexDouble());
  } else {
    out = chain_rule(exponent.toDouble());
  }
  return handle_r_to_c(self, std::move(out));
}

// Gradient of self ** exponent (tensor exponent) w.r.t. self:
// grad * conj(exponent * self^(exponent - 1)), forced to 0 wherever
// exponent == 0.
Tensor pow_backward_self(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& exponent) {
  const Tensor zero = at::zeros({}, grad.options());
  const Tensor local_grad = grad * (exponent * self.pow(exponent - 1)).conj();
  auto out = at::where(exponent == 0.0, zero, local_grad);
  return handle_r_to_c(self, std::move(out));
}

// Gradient of self ** exponent w.r.t. the (tensor) exponent:
// grad * conj(result * log(self)), where `result` is the forward output.
//
// Caveats:
// d(a^b)/db at a = 0 and b < 0 is defined to be -inf, since for a fixed b
// d(a^b)/db -> -inf as a -> +0. (Tensorflow currently defines it as nan for
// a = 0 and b < 0.)
//
// d(a^b)/db = 0 for a = 0 and b = 0, by continuity, since d(a^b)/db = 0 for
// a > 0 and b -> +0. (Tensorflow currently agrees.)
Tensor pow_backward_exponent(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& exponent,
    const Tensor& result) {
  // Non-negative real exponents: these positions get a zero gradient when
  // the base is 0.
  Tensor exponent_nonneg_real;
  if (exponent.is_complex()) {
    exponent_nonneg_real =
        at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
  } else {
    exponent_nonneg_real = exponent >= 0;
  }
  const Tensor zero_grad_mask = at::logical_and(self == 0, exponent_nonneg_real);

  // Promote the base to the common dtype before log(); `.to()` is a no-op
  // if the dtype already matches.
  const Tensor promoted_base = self.to(at::result_type(self, exponent));
  const Tensor d_exponent = (result * promoted_base.log()).conj();

  auto out = grad *
      at::where(zero_grad_mask, at::zeros({}, grad.options()), d_exponent);
  return handle_r_to_c(exponent, std::move(out));
}

// Gradient of base ** exponent w.r.t. the exponent, for a scalar base:
// grad * conj(result * log(base)). A real base is promoted to complex when
// the exponent is complex. When base == 0, positions with a non-negative real
// exponent get a zero gradient.
Tensor pow_backward_exponent(
    const Tensor& grad,
    const Scalar& base,
    const Tensor& exponent,
    const Tensor& result) {
  const Scalar promoted_base = (exponent.is_complex() && !base.isComplex())
      ? Scalar(base.toComplexDouble())
      : base;
  const Tensor d_exponent = (result * promoted_base.log()).conj();

  Tensor out;
  if (base.equal(0.0)) {
    const Tensor exponent_nonneg_real = exponent.is_complex()
        ? at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0)
        : exponent >= 0;
    out = grad *
        at::where(
              exponent_nonneg_real, at::zeros({}, grad.options()), d_exponent);
  } else {
    out = grad * d_exponent;
  }
  return handle_r_to_c(exponent, std::move(out));
}

// Gradient of angle(self). For complex input it is grad * self / |self|^2 * i,
// set to 0 where self == 0. For real input the gradient is all zeros
// (memory format preserved).
Tensor angle_backward(const Tensor& grad, const Tensor& self) {
  if (!self.is_complex()) {
    return at::zeros_like(self, at::MemoryFormat::Preserve);
  }
  const Scalar imag_unit(c10::complex<double>{0.0, 1.0});
  const Tensor local_grad = grad * self / self.abs().pow(2) * imag_unit;
  return at::where(self == 0.0, at::zeros({}, self.options()), local_grad);
}

// Gradient of mvlgamma(self, p): grad * sum_j digamma(self + o_j), where the
// offsets o_j run from -p/2 + 0.5 up to (but excluding) 0.5 in steps of 0.5,
// i.e. p offsets.
Tensor mvlgamma_backward(const Tensor& grad, const Tensor& self, int64_t p) {
  const double first_offset = -static_cast<double>(p) / 2. + 0.5;
  const Tensor offsets = at::arange(first_offset, 0.5, 0.5, self.options());
  Tensor shifted = offsets.add(self.unsqueeze(-1));
  return grad * shifted.digamma_().sum(-1);
}

// Gradient of sgn(x). For complex x:
// (gx - sgn^2 * conj(gx)) / (2|x|), set to 0 where |x| == 0.
// For real x the gradient is an efficient zero tensor shaped like `sgn`.
Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn) {
  if (!x.is_complex()) {
    return at::_efficientzerotensor(sgn.sizes(), sgn.options());
  }
  const auto magnitude = x.abs();
  const Tensor numerator = gx - (sgn * sgn) * gx.conj();
  return (numerator / (2. * magnitude)).masked_fill_(magnitude == 0., 0.);
}

// Gradient of masked_fill w.r.t. the fill value: sum of grad over the masked
// positions. masked_select does not work well with functorch because its
// output shape is data-dependent, so tensor-subclass inputs use where()
// instead.
Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask) {
  if (areAnyTensorSubclassLike({grad, mask})) {
    return at::where(mask, grad, 0).sum();
  }
  return grad.masked_select(mask).sum();
}

// Gradient of self * other w.r.t. self: grad * conj(other). The result is
// reduced to its real part if self's dtype `self_st` is real.
// Instantiated for Tensor and Scalar `other`.
template <typename T>
Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st) {
  Tensor scaled = grad * other.conj();
  return handle_r_to_c(self_st, std::move(scaled));
}
template Tensor mul_tensor_backward(const Tensor&, Tensor, ScalarType);
template Tensor mul_tensor_backward(const Tensor&, Scalar, ScalarType);

// Gradient of self / other w.r.t. self: grad / conj(other). With any rounding
// mode the gradient is zeros in self's dtype. The result is reduced to its
// real part if `self_st` is real. Instantiated for Tensor and Scalar `other`.
template <typename T>
Tensor div_tensor_self_backward(
    const Tensor& grad,
    T other,
    ScalarType self_st,
    const std::optional<c10::string_view>& rounding_mode) {
  if (!rounding_mode.has_value()) {
    auto quotient = grad / other.conj();
    return handle_r_to_c(self_st, std::move(quotient));
  }
  return at::zeros_like(grad, grad.options().dtype(self_st));
}
template Tensor div_tensor_self_backward(
    const Tensor&,
    Tensor,
    ScalarType,
    const std::optional<c10::string_view>&);
template Tensor div_tensor_self_backward(
    const Tensor&,
    Scalar,
    ScalarType,
    const std::optional<c10::string_view>&);

// Overload without a rounding mode: forwards with an empty rounding_mode.
template <typename T>
Tensor div_tensor_self_backward(
    const Tensor& grad,
    T other,
    ScalarType self_st) {
  const std::optional<c10::string_view> no_rounding = std::nullopt;
  return div_tensor_self_backward(
      grad, std::move(other), self_st, no_rounding);
}
template Tensor div_tensor_self_backward(const Tensor&, Tensor, ScalarType);
template Tensor div_tensor_self_backward(const Tensor&, Scalar, ScalarType);

// Gradient of self / other w.r.t. other: -grad * conj((self / other) / other).
// With any rounding mode the gradient is zeros in other's dtype. The result
// is reduced to its real part if `other` is real.
Tensor div_tensor_other_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& other,
    const std::optional<c10::string_view>& rounding_mode) {
  if (rounding_mode.has_value()) {
    return at::zeros_like(grad, grad.options().dtype(other.scalar_type()));
  }
  const Tensor quotient = self / other;
  auto result = -grad * (quotient / other).conj();
  return handle_r_to_c(other, std::move(result));
}

// Overload without a rounding mode: forwards with an empty rounding_mode.
Tensor div_tensor_other_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& other) {
  const std::optional<c10::string_view> no_rounding = std::nullopt;
  return div_tensor_other_backward(grad, self, other, no_rounding);
}

// Backward of permute: applies the inverse permutation to `grad`.
// Output dim i of the forward came from input dim fwd_dims[i] (which may be
// negative), so inverse[wrap(fwd_dims[i])] = i.
Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) {
  const auto rank = static_cast<int64_t>(fwd_dims.size());
  std::vector<int64_t> inverse(fwd_dims.size());
  for (int64_t out_dim = 0; out_dim < rank; ++out_dim) {
    inverse[at::maybe_wrap_dim(fwd_dims[out_dim], rank)] = out_dim;
  }
  return grad.permute(inverse);
}

// Gradient of rad2deg: grad scaled by the constant 180 / pi.
Tensor rad2deg_backward(const Tensor& grad) {
  constexpr double kDegreesPerRadian =
      57.295779513082320876798154814105170332405472466564;
  return at::mul(grad, Scalar(kDegreesPerRadian));
}

// Gradient of deg2rad: grad scaled by the constant pi / 180.
Tensor deg2rad_backward(const Tensor& grad) {
  constexpr double kRadiansPerDegree =
      0.017453292519943295769236907684886127134428718885417;
  return at::mul(grad, Scalar(kRadiansPerDegree));
}

// Inserts a size-1 dim at every position listed in `opt_dim`, where those
// positions index a tensor of rank `n_dims`. A missing `opt_dim` is handed to
// at::dim_list_to_bitset as-is.
Tensor unsqueeze_multiple(
    const Tensor& t,
    OptionalIntArrayRef opt_dim,
    size_t n_dims) {
  if (opt_dim.has_value()) {
    const IntArrayRef dims = opt_dim.value();
    // Fast paths for the two most common cases.
    switch (dims.size()) {
      case 0:
        return t;
      case 1:
        return t.unsqueeze(dims[0]);
      default:
        break;
    }
  }
  const auto insert_at = at::dim_list_to_bitset(opt_dim, n_dims);
  Tensor expanded = t;
  // Ascending order so earlier insertions land at their final positions.
  for (size_t pos = 0; pos < n_dims; ++pos) {
    if (insert_at[pos]) {
      expanded = expanded.unsqueeze(static_cast<int64_t>(pos));
    }
  }
  return expanded;
}

// Backward of sum: broadcasts `grad` back to `sizes`. When keepdim was false
// and specific dims were reduced, those dims are first re-inserted as size-1
// dims.
Tensor sum_backward(
    const Tensor& grad,
    c10::SymIntArrayRef sizes,
    OptionalIntArrayRef opt_dims,
    bool keepdim) {
  const bool reinsert_dims = !keepdim && !sizes.empty() &&
      opt_dims.has_value() && !opt_dims.value().empty();
  if (reinsert_dims) {
    return unsqueeze_multiple(grad, opt_dims, sizes.size())
        .expand_symint(sizes);
  }
  return grad.expand_symint(sizes);
}

// Backward of sum with the reduced dims given as a plain IntArrayRef:
// broadcasts `grad` back to `sizes`.
// When keepdim was false, the reduced dims were dropped by the forward. They
// are re-inserted as size-1 dims before expanding, exactly as the
// OptionalIntArrayRef overload does. This works for symbolic sizes because
// only the rank (sizes.size()) is needed to place the new dims.
// (Previously this path threw NotImplemented and fell off the end of the
// function without a return statement.)
Tensor sum_backward(
    const Tensor& grad,
    c10::SymIntArrayRef sizes,
    c10::IntArrayRef dims,
    bool keepdim) {
  if (!keepdim && !sizes.empty() && !dims.empty()) {
    return unsqueeze_multiple(grad, dims, sizes.size()).expand_symint(sizes);
  }
  return grad.expand_symint(sizes);
}

// Backward of nansum: the sum gradient, zeroed at positions where `self` is
// NaN.
Tensor nansum_backward(
    const Tensor& grad,
    const Tensor& self,
    at::OptionalIntArrayRef dims,
    bool keepdim) {
  const Tensor not_nan = self.isnan().logical_not();
  const Tensor broadcast_grad =
      sum_backward(grad, self.sym_sizes(), dims, keepdim);
  return broadcast_grad * not_nan;
}

// Backward of mean: the sum gradient divided by the number of averaged
// elements. That count is `numel` for a full reduction (no dim, or an empty
// dim list); otherwise it is the product of the reduced sizes.
Tensor mean_backward(
    const Tensor& grad,
    c10::SymIntArrayRef shape,
    OptionalIntArrayRef opt_dim,
    c10::SymInt numel,
    bool keepdim) {
  const bool full_reduction = !opt_dim.has_value() || opt_dim.value().empty();
  if (!full_reduction) {
    numel = _safe_size(shape, opt_dim.value());
  }
  return sum_backward(grad, shape, opt_dim, keepdim) / std::move(numel);
}

// Returns a copy of `list` in reverse order.
std::vector<c10::SymInt> reverse_list_symint(const c10::SymIntArrayRef list) {
  return std::vector<c10::SymInt>(list.rbegin(), list.rend());
}

// Returns a copy of `list` in reverse order.
std::vector<int64_t> reverse_list(const IntArrayRef list) {
  return std::vector<int64_t>(list.rbegin(), list.rend());
}

// Zero-safe gradient of prod along `dim` (expects a non-negative dim).
// Each element's gradient is grad times the product of all OTHER elements
// along `dim`. That product is built as an exclusive cumulative product from
// the left times one from the right, with no division, so zeros in `inp`
// are handled.
Tensor prod_safe_zeros_backward(
    const Tensor& grad,
    const Tensor& inp,
    int64_t dim) {
  if (inp.sym_numel() == 0) {
    // Empty input: nothing to compute, just broadcast grad to inp's shape.
    return grad.expand_as(inp);
  }
  const c10::SymInt len = inp.sym_size(dim);
  if (len == 1) {
    return grad;
  }

  auto ones_shape = inp.sym_sizes().vec();
  ones_shape[dim] = 1;
  Tensor ones = at::ones_symint(ones_shape, grad.options());

  // Left-exclusive products: [1, a, a*b, ...].
  Tensor left_input = at::cat({ones, inp.narrow_symint(dim, 0, len - 1)}, dim);
  Tensor left_prod = left_input.cumprod(dim);

  // Right-exclusive products: [..., c, 1], built on the flipped tail.
  Tensor tail_flipped = inp.narrow_symint(dim, 1, len - 1).flip(dim);
  Tensor right_input =
      at::cat({std::move(ones), std::move(tail_flipped)}, dim);
  Tensor right_prod = right_input.cumprod(dim).flip(dim);

  return grad * (left_prod * right_prod).conj();
}

// Gradient of the full-reduction prod. It is equivalent to
// cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.:
// input:                        [    a,     b,     c]
// cumprod(exclusive, normal):   [1    ,     a, a * b]
// cumprod(exclusive, reverse):  [b * c,     c,     1]
// product:                      [b * c, a * c, a * b]
// That form is safe for inputs containing 0s. The cheaper result / input form
// is used when there are no zeros. With grad mode off and two or more zeros,
// a zero gradient is returned directly.
Tensor prod_backward(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& result) {
  if (input.dim() == 0) {
    return grad;
  }
  // For Composite Compliance, meta and tensor-subclass inputs always take the
  // safer (and slower) path.
  const bool force_safe_path = input.is_meta() || isTensorSubclassLike(input);
  if (!force_safe_path) {
    Tensor zero_positions = (input == 0).nonzero();
    if (zero_positions.sym_numel() == 0) {
      return grad * (result / input).conj();
    }
    if (!at::GradMode::is_enabled() && zero_positions.sym_size(0) > 1) {
      return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
  }
  return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0)
      .view_as(input);
}

// Gradient of prod along a single dim `dim` (may be negative).
// When keepdim is false, grad and result are first unsqueezed at `dim` so
// they broadcast against `input`. If `input` contains no zeros, the cheap
// form grad * conj(result / input) is used. Otherwise, and always for meta or
// tensor-subclass inputs (composite compliance), the zero-safe
// cumulative-product path is taken.
Tensor prod_backward(
    Tensor grad,
    const Tensor& input,
    Tensor result,
    int64_t dim,
    bool keepdim) {
  if (input.dim() == 0) {
    return grad;
  }
  dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(input.sym_sizes().size()));
  if (!keepdim) {
    // `prod` reduces the dimension at `dim`,
    // so, unsqueeze `grad` and `result` at dim.
    grad = grad.unsqueeze(dim);
    result = result.unsqueeze(dim);
  }
  if (input.is_meta() || isTensorSubclassLike(input)) {
    // For Composite Compliance, always take the safer (and slower) path
    return prod_safe_zeros_backward(grad, input, dim);
  }

  // Only the presence of a zero matters. A single any() reduction replaces
  // the previous per-slice sum followed by a global sum, and still needs just
  // one host sync via item().
  const bool has_zero = (input == 0).any().item<bool>();
  if (!has_zero) {
    return grad * (result / input).conj();
  }
  return prod_safe_zeros_backward(grad, input, dim);
}

// Shared JVP helper for linear solves A X = B. It computes dA @ X (treating a
// vector right-hand side as a column) and passes A, dB and that product to
// `solve`. In general dX = solve(A, dB - dA @ X), but lu_solve combines the
// terms differently (see lu_solve_jvp), so the combination is left to
// `solve`.
template <typename solve_f>
static Tensor generic_solve_jvp(
    solve_f solve,
    const Tensor& X,
    const Tensor& A,
    const Tensor& dA,
    const Tensor& dB) {
  Tensor dA_contrib;
  if (at::native::linalg_solve_is_vector_rhs(dA, dB)) {
    dA_contrib = dA.matmul(X.unsqueeze(-1)).squeeze(-1);
  } else {
    dA_contrib = dA.matmul(X);
  }
  return solve(A, dB, dA_contrib);
}

// Backward of cumsum along `dim`. Input position j receives the sum of the
// incoming gradients at positions j, j+1, ..., which is a reversed cumsum.
Tensor cumsum_backward(const Tensor& grad, int64_t dim) {
  // With at most one element, or a length-1 scan dim, cumsum is the identity.
  const bool is_identity = grad.sym_numel() <= 1 || grad.sym_size(dim) == 1;
  if (is_identity) {
    return grad;
  }
  const auto reversed = grad.flip(dim);
  return reversed.cumsum(dim).flip(dim);
}

// Backward of logsumexp over `dim`, using d/dx_i logsumexp(x) = exp(x_i - result).
// When the reduced dims were dropped (keepdim == false) and `self` is not
// 0-dim, grad and result are unsqueezed so they broadcast against `self`.
Tensor logsumexp_backward(
    Tensor grad,
    const Tensor& self,
    Tensor result,
    IntArrayRef dim,
    bool keepdim) {
  const bool restore_reduced_dims = !keepdim && self.dim() != 0;
  if (restore_reduced_dims) {
    const auto self_ndim = self.sym_sizes().size();
    grad = unsqueeze_multiple(grad, dim, self_ndim);
    result = unsqueeze_multiple(result, dim, self_ndim);
  }
  const auto weights = (self - result).exp();
  return grad * weights;
}

// Backward of y = logcumsumexp(x, dim).
//
// dL/dx_j = sum_{i >= j} g_i * exp(x_j - y_i)
//         = exp(x_j + reverse_logcumsumexp(log(g) - y)_j),
// i.e. a reversed logcumsumexp evaluated in log space. For real inputs log(g)
// only exists for g > 0, so the gradient is split into its positive and
// negative parts. Each part is handled in log space and the two results are
// subtracted.
//
// Args:
//   grad    gradient with respect to `result`
//   self    forward input x
//   result  forward output y = logcumsumexp(self, dim)
//   dim     dimension of the cumulative reduction
Tensor logcumsumexp_backward(
    Tensor grad,
    const Tensor& self,
    Tensor result,
    int64_t dim) {
  // Nothing to propagate for 0-dim or empty gradients.
  if (grad.dim() == 0 || grad.sym_numel() == 0) {
    return grad;
  }

  // Reference: https://github.com/tensorflow/tensorflow/blob/
  // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863

  // Lowest finite value of grad's dtype. It stands in for log(0) at entries
  // excluded from a branch below, so they contribute ~exp(lowest) == 0.
  auto scalar_min = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
      at::ScalarType::BFloat16,
      at::typeMetaToScalarType(grad.dtype()),
      "logcumsumexp_backward",
      []() { return c10::Scalar(std::numeric_limits<scalar_t>::lowest()); });

  // logcumsumexp accumulated from the end of `dim` towards the start.
  auto reverse_logcumsumexp = [dim](auto x) {
    return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
  };

  if (!at::is_complex(grad)) {
    // log|g| restricted to the positive / negative entries of g.
    auto grad_min = at::scalar_tensor(scalar_min, grad.options());
    auto log_abs_grad = grad.abs().log();
    auto log_grad_positive = at::where(grad > 0, log_abs_grad, grad_min);
    auto log_grad_negative = at::where(grad < 0, log_abs_grad, grad_min);

    auto output_pos =
        (reverse_logcumsumexp(log_grad_positive - result) + self).exp();
    auto output_neg =
        (reverse_logcumsumexp(log_grad_negative - result) + self).exp();

    return output_pos - output_neg;
  } else {
    // no trick separating the positive and negative required
    // (complex log is defined for any nonzero value). The conj() calls
    // presumably follow PyTorch's conjugate-gradient convention.
    auto log_grad = grad.conj().log();
    auto output = (reverse_logcumsumexp(log_grad - result) + self).exp();
    return output.conj();
  }
}

// Forward-mode derivative of logcumsumexp along `dim`:
//   tangent = cumsum(exp(x) * t) / cumsum(exp(x)).
// For real inputs exp is shifted by the max along `dim` (exp-normalize
// trick). The shift cancels in the ratio. Adapted from logsumexp_jvp; values
// available from the forward pass are recomputed here for simplicity.
Tensor logcumsumexp_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    int64_t dim) {
  Tensor weights;
  if (at::is_complex(self_p)) {
    // at::max doesn't support complex128
    weights = self_p.exp();
  } else {
    const auto max_along_dim = std::get<0>(at::max(self_p, dim, true));
    weights = (self_p - max_along_dim).exp();
  }

  auto denominator = weights.cumsum(dim);

  TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor());

  // Small constant keeping the denominator away from zero.
  constexpr double eps = 1e-13;
  denominator.add_(eps);

  if (!areAnyTensorSubclassLike({self_p, self_t})) {
    // Plain tensors: reuse the `weights` buffer in place.
    weights *= self_t;
    auto numerator = weights.cumsum(dim);
    return numerator /= denominator;
  }
  // Tensor subclasses: out-of-place product, presumably because writing a
  // subclass result into the plain `weights` buffer is not safe.
  auto numerator = (weights * self_t).cumsum(dim);
  numerator /= denominator;
  return numerator;
}

// Backward of unbind along `dim`: stacks the per-slice gradients back along
// `dim`. Slices without a gradient are filled with zeros whose sizes and
// options come from the first defined gradient.
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
  const auto first_defined =
      std::find_if(grads.begin(), grads.end(), [](const Variable& g) {
        return g.defined();
      });

  c10::SymIntArrayRef fill_sizes;
  at::TensorOptions fill_options;
  if (first_defined != grads.end()) {
    fill_sizes = first_defined->sym_sizes();
    fill_options = static_cast<Tensor>(*first_defined).options();
  }

  std::vector<Tensor> slices;
  slices.reserve(grads.size());
  for (const Variable& g : grads) {
    if (g.defined()) {
      slices.push_back(static_cast<Tensor>(g));
    } else {
      slices.push_back(at::zeros({}, fill_options).expand_symint(fill_sizes));
    }
  }
  return at::stack(slices, dim);
}

// Backward of unbind for (strided-layout) nested tensors. Rebuilds a nested
// tensor from the component gradients. A missing gradient becomes zeros whose
// shape is read from the matching row of `nt_sizes`. `dim` is not read here.
Tensor unbind_backward_nested(
    const variable_list& grads,
    const Tensor& nt_sizes,
    int64_t dim,
    const at::TensorOptions& options) {
  std::vector<Tensor> components;
  components.reserve(grads.size());
  int64_t component_idx = 0;
  for (const auto& g : grads) {
    if (!g.defined()) {
      // Row `component_idx` of nt_sizes holds this component's sizes.
      const auto sizes_row = nt_sizes[component_idx].contiguous();
      const c10::IntArrayRef zero_shape(
          sizes_row.data_ptr<int64_t>(), sizes_row.size(0));
      components.push_back(at::zeros(zero_shape, options));
    } else {
      components.push_back(static_cast<Tensor>(g));
    }
    ++component_idx;
  }
  return at::_nested_tensor_from_tensor_list(components);
}

// Backward of unbind(dim=0) for jagged nested tensors. Starts from zeros
// shaped like `self` and copies each defined gradient into the matching
// component. This relies on unbind() returning views of the zero tensor.
// Undefined gradients leave their component at zero.
Tensor unbind_backward_nested_jagged(
    const variable_list& grads,
    const Tensor& self,
    int64_t dim) {
  TORCH_INTERNAL_ASSERT(
      dim == 0, "unbind_backward_nested_jagged() only supports dim=0");
  auto result = at::zeros_like(self);
  auto component_views = result.unbind();
  for (const auto i : c10::irange(grads.size())) {
    const auto& component_grad = grads[i];
    if (!component_grad.defined()) {
      continue;
    }
    component_views[i].copy_(static_cast<Tensor>(component_grad));
  }
  return result;
}

// Inserts a size-1 dimension into `self` at every index i where
// sym_sizes[i] == 1, visiting indices in increasing order.
Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
  Tensor out = self;
  int64_t pos = 0;
  for (const c10::SymInt& extent : sym_sizes) {
    if (extent == 1) {
      out = out.unsqueeze(pos);
    }
    ++pos;
  }
  return out;
}

// Like unsqueeze_to(self, sym_sizes), but only the indices listed in `dims`
// (wrapped against sym_sizes.size()) are candidates for re-insertion.
Tensor unsqueeze_to(
    const Tensor& self,
    IntArrayRef dims,
    c10::SymIntArrayRef sym_sizes) {
  const auto rank = sym_sizes.size();
  const auto selected = at::dim_list_to_bitset(dims, rank);

  Tensor out = self;
  for (size_t d = 0; d < rank; ++d) {
    const bool reinsert = selected.test(d) && sym_sizes[d] == 1;
    if (reinsert) {
      out = out.unsqueeze(static_cast<int64_t>(d));
    }
  }
  return out;
}

// Single-dimension convenience overload; forwards to the IntArrayRef version.
Tensor unsqueeze_to(
    const Tensor& self,
    int64_t dim,
    c10::SymIntArrayRef sym_sizes) {
  // IntArrayRef(dim) refers to the parameter `dim`, which outlives the call.
  return unsqueeze_to(self, IntArrayRef(dim), sym_sizes);
}

std::vector<Tensor> cat_tensors_backward(
    const Tensor& grad,
    const std::vector<std::vector<c10::SymInt>>& sizes,
    const std::vector<ScalarType>& dtypes,
    int64_t dim) {
  std::vector<Tensor> grad_inputs(sizes.size());
  if (!grad.defined()) {
    return grad_inputs;
  }
  dim = at::legacy_cat_wrap_dim_symint(dim, sizes);
  c10::SymInt accumulate = 0;

  Tensor grad_;
  bool grad_is_complex = grad.is_complex();
  if (grad_is_complex) {
    grad_ = at::real(grad);
  }
  for (const auto i : c10::irange(sizes.size())) {
    Tensor grad_val;
    if (!at::isComplexType(dtypes[i]) && grad_is_complex) {
      // R -> C
      grad_val = grad_;
    } else {
      grad_val = grad;
    }
    auto& shape = sizes[i];
    // If input was empty tensor, gradInput should be empty tensor.
    if (shape.size() == 1) {
      if (TORCH_GUARD_SIZE_OBLIVIOUS(shape[0].sym_eq(0))) {
        grad_inputs[i] = at::zeros({0}, grad_val.options());
        continue;
      }
    }
    const auto& size = shape[dim];
    accumulate += size;
    grad_inputs[i] = grad_val.narrow_symint(dim, accumulate - size, size);
  }
  return grad_inputs;
}

std::vector<Tensor> stack_tensors_backward(
    const Tensor& grad,
    int64_t dim,
    const std::vector<ScalarType>& dtypes) {
  std::vector<Tensor> grad_inputs(dtypes.size());
  if (!grad.defined()) {
    return grad_inputs;
  }
  bool grad_is_complex = grad.is_complex();
  for (const auto i : c10::irange(dtypes.size())) {
    auto gr = grad.select(dim, static_cast<int64_t>(i));
    if (grad_is_complex && !at::isComplexType(dtypes[i])) {
      gr = at::real(gr);
    }
    grad_inputs[i] = gr;
  }
  return grad_inputs;
}

// Backward of block_diag: cuts each input's diagonal block out of `grad`.
// A 2-d input of shape (r, c) occupies an r x c block, a 1-d input of length
// N a 1 x N block, and a 0-d input a 1 x 1 block. Each block is then squeezed
// back to the input's rank. Real inputs receive the real part of a complex
// gradient, and legacy empty inputs (shape [0]) receive an empty gradient.
std::vector<Tensor> block_diag_backward(
    const Tensor& grad,
    const std::vector<std::vector<int64_t>>& sizes,
    const std::vector<ScalarType>& dtypes) {
  std::vector<Tensor> grad_inputs(sizes.size());
  if (!grad.defined()) {
    return grad_inputs;
  }
  Tensor real_view_of_grad;
  bool grad_is_complex = grad.is_complex();
  if (grad_is_complex) {
    real_view_of_grad = at::real(grad);
  }

  // Top-left corner of the next block.
  int64_t cur_dim0 = 0;
  int64_t cur_dim1 = 0;

  for (const auto i : c10::irange(sizes.size())) {
    // R -> C
    Tensor grad_val = (!at::isComplexType(dtypes[i]) && grad_is_complex)
        ? real_view_of_grad
        : grad;

    auto& shape = sizes[i];
    // If input was empty tensor, gradInput should be empty tensor.
    if (shape.size() == 1 && shape[0] == 0) {
      grad_inputs[i] = at::zeros({0}, grad_val.options());
      continue;
    }
    // 0d case
    int64_t dim0 = 1;
    int64_t dim1 = 1;
    // 2d case
    if (shape.size() == 2) {
      dim0 = shape[0];
      dim1 = shape[1];
      // 1d case
    } else if (shape.size() == 1) {
      dim1 = shape[0];
    }
    auto slice = grad_val.slice(0, cur_dim0, cur_dim0 + dim0)
                     .slice(1, cur_dim1, cur_dim1 + dim1);
    if (shape.size() == 1) {
      // A 1-d input of length N occupies a (1, N) block. Drop the leading
      // row dim so the gradient has the input's shape (N,). squeeze(-1)
      // would only have removed a dim when N == 1.
      slice = slice.squeeze(0);
    } else if (shape.empty()) {
      slice = slice.squeeze(-1).squeeze(-1);
    }
    grad_inputs[i] = slice;
    cur_dim0 += dim0;
    cur_dim1 += dim1;
  }
  return grad_inputs;
}

// Backward of clamp with optional scalar bounds. The gradient passes where
// `self` lies inside the (inclusive) bounds and is zero elsewhere. The
// gradient is not defined exactly at min and max; the subgradient 1 is used
// there.
Tensor clamp_backward(
    const Tensor& grad,
    const Tensor& self,
    const std::optional<Scalar>& min,
    const std::optional<Scalar>& max) {
  if (!min && !max) {
    return grad;
  }
  auto zero = at::scalar_tensor(0., grad.options());
  Tensor in_range;
  if (min && max) {
    in_range = (self >= *min).logical_and_(self <= *max);
  } else if (min) {
    in_range = self >= *min;
  } else {
    in_range = self <= *max;
  }
  return where(in_range, grad, zero);
}

// Backward of clamp with tensor bounds, with respect to `self`. The gradient
// passes where min <= self <= max and is zero elsewhere. The gradient is not
// defined exactly at min and max; the subgradient 1 is used there.
Tensor clamp_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& min,
    const Tensor& max) {
  if (max.defined() && min.defined()) {
    auto zero = at::scalar_tensor(0., grad.options());
    // (self >= min) has shape broadcast(self, min), while (self <= max) has
    // shape broadcast(self, max). These can differ, e.g. self (5,), min (5,),
    // max (2, 5). An in-place logical_and_ on the first would then fail with
    // a broadcast-shape error, so combine them out-of-place.
    const auto pred = (self >= min).logical_and(self <= max);
    return where(pred, grad, zero);
  } else if (min.defined()) {
    auto zero = at::scalar_tensor(0., grad.options());
    return where(self >= min, grad, zero);
  } else if (max.defined()) {
    auto zero = at::scalar_tensor(0., grad.options());
    return where(self <= max, grad, zero);
  } else {
    return grad;
  }
}

std::tuple<at::Tensor, at::Tensor> clamp_backward_min_max(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& min,
    const Tensor& max,
    const std::array<bool, 2>& grad_input_mask) {
  // If min > max, min has no gradient
  std::tuple<at::Tensor, at::Tensor> ret;
  if (!grad.defined()) {
    return ret;
  }

  auto zero = at::scalar_tensor(0., grad.options());
  if (max.defined() && min.defined()) {
    if (grad_input_mask[0]) {
      const auto self_lt_min = self < min;
      const auto min_lt_max = min < max;
      const auto& pred = areAnyTensorSubclassLike({self, min, max})
          ? self_lt_min.logical_and(min_lt_max)
          : self_lt_min.logical_and_(min_lt_max);
      std::get<0>(ret) = where(pred, grad, zero);
    }
    if (grad_input_mask[1]) {
      const auto self_gt_max = self > max;
      const auto max_lt_min = max < min;
      const auto& pred = areAnyTensorSubclassLike({self, min, max})
          ? self_gt_max.logical_or(max_lt_min)
          : self_gt_max.logical_or_(max_lt_min);
      std::get<1>(ret) = where(pred, grad, zero);
    }
  } else if (min.defined() && grad_input_mask[0]) {
    std::get<0>(ret) = where(self < min, grad, zero);
  } else if (max.defined() && grad_input_mask[1]) {
    std::get<1>(ret) = where(self > max, grad, zero);
  }
  return ret;
}

// Forward-mode derivative of clamp with tensor bounds: propagates the tangent
// of whichever operand the forward pass selected. Ties are resolved the same
// way in every branch, and the same way as clamp_backward (`self >= min`,
// `self <= max`): self keeps its tangent when it equals a bound. When
// min > max the output is max.
at::Tensor clamp_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    const Tensor& min_p,
    const Tensor& min_t,
    const Tensor& max_p,
    const Tensor& max_t) {
  if (min_p.defined() && max_p.defined()) {
    return where(
        min_p > max_p,
        max_t,
        where(self_p < min_p, min_t, where(self_p > max_p, max_t, self_t)));
  } else if (min_p.defined()) {
    // Previously `self_p > min_p`, which gave ties to min_t, unlike the
    // two-bound branch above and the backward formula.
    return where(self_p >= min_p, self_t, min_t);
  } else if (max_p.defined()) {
    return where(self_p <= max_p, self_t, max_t);
  } else {
    return self_t;
  }
}

// Forward-mode derivative of convolution. The op is bilinear in
// (input, weight) and affine in bias, so
//   d(out) = conv(input_t, weight_p) + conv(input_p, weight_t) + bias_t.
// bias_t (when defined) is folded into the second convolution. bias_p is not
// needed.
Tensor convolution_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& weight_p,
    const Tensor& weight_t,
    const Tensor& bias_p,
    const Tensor& bias_t,
    at::SymIntArrayRef stride,
    at::SymIntArrayRef padding,
    at::SymIntArrayRef dilation,
    bool transposed,
    at::SymIntArrayRef output_padding,
    const c10::SymInt& groups) {
  std::optional<at::Tensor> bias_tangent;
  if (bias_t.defined()) {
    bias_tangent = bias_t;
  }
  auto from_input_tangent = at::convolution_symint(
      input_t,
      weight_p,
      std::nullopt,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      groups);
  auto from_weight_tangent = at::convolution_symint(
      input_p,
      weight_t,
      bias_tangent,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      groups);
  return from_input_tangent + from_weight_tangent;
}

// Forward-mode derivative of _convolution (the backend-flag variant). It uses
// the same product rule as convolution_jvp:
//   d(out) = _conv(input_t, weight_p) + _conv(input_p, weight_t) + bias_t,
// and forwards the benchmark/deterministic/cudnn/tf32 flags unchanged.
Tensor _convolution_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& weight_p,
    const Tensor& weight_t,
    const Tensor& bias_p,
    const Tensor& bias_t,
    at::SymIntArrayRef stride,
    at::SymIntArrayRef padding,
    at::SymIntArrayRef dilation,
    bool transposed,
    at::SymIntArrayRef output_padding,
    const c10::SymInt& groups,
    bool benchmark,
    bool deterministic,
    bool cudnn_enabled,
    bool allow_tf32) {
  std::optional<at::Tensor> bias_tangent;
  if (bias_t.defined()) {
    bias_tangent = bias_t;
  }
  auto from_input_tangent = at::_convolution_symint(
      input_t,
      weight_p,
      std::nullopt,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      groups,
      benchmark,
      deterministic,
      cudnn_enabled,
      allow_tf32);
  auto from_weight_tangent = at::_convolution_symint(
      input_p,
      weight_t,
      bias_tangent,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      groups,
      benchmark,
      deterministic,
      cudnn_enabled,
      allow_tf32);
  return from_input_tangent + from_weight_tangent;
}

// Tangent of the bias gradient in convolution_backward. It sums grad_out_t
// over every dim except dim 1 (presumably the channel dim that the bias is
// indexed by). Returns an undefined tensor when no bias gradient exists.
// Only 1-3 spatial dims are supported.
Tensor convolution_backward_jvp_grad_bias(
    const Tensor& grad_out_t,
    const Tensor& grad_bias) {
  if (!grad_bias.defined()) {
    return Tensor();
  }
  switch (grad_out_t.dim() - 2) {
    case 1: {
      // Cannot pass initializer list due to overload ambiguity
      const auto dimlist = std::vector<int64_t>{0, 2};
      return grad_out_t.sum(dimlist);
    }
    case 2:
      return grad_out_t.sum({0, 2, 3});
    case 3:
      return grad_out_t.sum({0, 2, 3, 4});
    default:
      TORCH_INTERNAL_ASSERT(
          false,
          "convolution_backward_jvp_grad_bias expected dim of grad_out_t to be 3, 4, or 5, but got: ",
          grad_out_t.dim());
  }
}

// Used by load_derivatives.py in place of `tensor.strides()` calls in
// derivative formulas.
//
// Returns the strides of `input` when it requires grad. Returns an empty array
// when:
//   * `input` does not require grad (no backward pass will need the strides);
//   * `input` is an MKLDNN tensor, which is opaque and has no stride info;
//   * `input` is sparse (COO or compressed).
//
// Only supports the case where `input` is the tensor whose single derivative
// is being calculated. Does not support `self` derivatives for inplace
// functions.
//
// Args:
//  input              Tensor to call .strides() on
//  input_name         Name of `input` tensor, from derivative formula
at::SymIntArrayRef strides_or_error(
    const Tensor& input,
    c10::string_view const& input_name) {
  // TODO: Ideally, this function would never be called if requires_grad is
  // not set. Once codegen is updated to avoid the call, we can remove this
  // check.
  const bool no_usable_strides = !input.requires_grad() || input.is_mkldnn() ||
      input.is_sparse() || at::sparse_csr::is_sparse_compressed(input);
  if (no_usable_strides) {
    return {};
  }
  return input.sym_strides();
}

// Gradient of alpha * mat1 @ mat2 with respect to mat1:
//   conj(alpha) * grad @ mat2^H.
// If mat1 was strided and column-major (and all operands are strided), the
// equivalent (mat2.conj() @ grad^T)^T is computed so that the result is also
// column-major.
Tensor mm_mat1_backward(
    const Tensor& grad,
    const Tensor& mat2,
    at::SymIntArrayRef mat1_sizes,
    at::SymIntArrayRef mat1_strides,
    c10::Layout mat1_layout,
    const Scalar& alpha) {
  const bool all_strided = grad.layout() == c10::kStrided &&
      mat2.layout() == c10::kStrided && mat1_layout == c10::kStrided;
  const bool mat1_column_major = all_strided && mat1_strides[0] == 1 &&
      mat1_strides[1] == mat1_sizes[0];
  if (mat1_column_major) {
    return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha.conj());
  }
  // General fallback, should work for any layout
  return maybe_multiply(grad.mm(mat2.t().conj()), alpha.conj());
}

// Gradient of alpha * mat1 @ mat2 with respect to mat2:
//   conj(alpha) * mat1^H @ grad.
// If mat2 was strided and column-major (and all operands are strided), the
// equivalent (grad^T @ mat1.conj())^T is computed so that the result is also
// column-major.
Tensor mm_mat2_backward(
    const Tensor& grad,
    const Tensor& mat1,
    at::SymIntArrayRef mat2_sizes,
    at::SymIntArrayRef mat2_strides,
    c10::Layout mat2_layout,
    const Scalar& alpha) {
  const bool all_strided = grad.layout() == c10::kStrided &&
      mat1.layout() == c10::kStrided && mat2_layout == c10::kStrided;
  const bool mat2_column_major = all_strided && mat2_strides[0] == 1 &&
      mat2_strides[1] == mat2_sizes[0];
  if (mat2_column_major) {
    return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha.conj());
  }
  // General fallback, should work for any layout
  return maybe_multiply(mat1.t().conj().mm(grad), alpha.conj());
}

// Gradient of alpha * mat1 @ mat2 with respect to mat1 when mat1 may be
// sparse: conj(alpha) * grad @ mat2^H, restricted to mat1's sparsity pattern
// when mat1 is sparse (COO or CSR). grad and mat2 must be strided.
//
// alpha is conjugated to match mm_mat1_backward, mm_mat2_backward and
// sparse_sampled_addmm_backward; this only matters for complex alpha.
Tensor mm_mat1_sparse_backward(
    const Tensor& grad,
    const Tensor& mat1,
    const Tensor& mat2,
    const Scalar& alpha) {
  const bool strided_operands =
      grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided;
  if (strided_operands && mat1.is_sparse()) {
    auto sparse = mat1.coalesce();
    Tensor grad_sparse =
        maybe_multiply(grad.mm(mat2.conj().t()), alpha.conj());
    return grad_sparse.sparse_mask(sparse);
  }
  if (strided_operands && mat1.is_sparse_csr()) {
    // zero must have mat1's sparsity pattern:
    auto zero = mat1.clone();
    zero.values().zero_();
    return at::sparse_sampled_addmm(zero, grad, mat2.mH(), 1.0, alpha.conj());
  }
  if (strided_operands && mat1.layout() == c10::kStrided) {
    return maybe_multiply(grad.mm(mat2.mH()), alpha.conj());
  }
  TORCH_CHECK(
      false,
      "mm_mat1_sparse_backward: unsupported combination of layouts",
      ", grad: ",
      grad.layout(),
      ", mat1: ",
      mat1.layout(),
      ", mat2: ",
      mat2.layout());
}

// Restricts the sparse gradient `gx` to the sparsity pattern of `x`. It picks
// between gx._sparse_mask_projection(x, ...) and gx.sparse_mask(x) based on
// which operands are already coalesced and on their nnz. When neither is
// coalesced, whichever side makes the follow-up search cheaper is coalesced.
static Tensor sparse_mask_like_grad(
    const Tensor& x,
    const Tensor& gx,
    bool accumulate_matches) {
  const bool x_coalesced = x.is_coalesced();
  const bool gx_coalesced = gx.is_coalesced();
  const bool x_is_denser = x._nnz() >= gx._nnz();

  if (x_coalesced && gx_coalesced) {
    // Searching into the larger operand is faster.
    return x_is_denser ? gx._sparse_mask_projection(x, accumulate_matches)
                       : gx.sparse_mask(x);
  }
  if (x_coalesced) {
    return gx.sparse_mask(x);
  }
  if (gx_coalesced) {
    return gx._sparse_mask_projection(x, accumulate_matches);
  }
  // Neither coalesced: coalesce the smaller one (likely cheaper).
  return x_is_denser
      ? gx.coalesce()._sparse_mask_projection(x, accumulate_matches)
      : gx.sparse_mask(x.coalesce());
}

std::tuple<Tensor, Tensor, Tensor> sparse_sampled_addmm_backward(
    const Tensor& grad,
    const Tensor& self,
    const std::optional<Tensor>& mat1,
    const std::optional<Tensor>& mat2,
    const Scalar& alpha,
    const Scalar& beta,
    const std::array<bool, 3>& grad_input_mask) {
  if (!grad.defined()) {
    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
  }

  const auto grad_projected = grad.sparse_mask(self);
  const auto self_requires_grad = grad_input_mask[0];
  const auto mat1_requires_grad = grad_input_mask[1];
  const auto mat2_requires_grad = grad_input_mask[2];
  return std::make_tuple(
      self_requires_grad ? maybe_multiply(grad, beta.conj()) : Tensor{},
      mat1_requires_grad
          // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
          ? maybe_multiply(grad_projected.mm(mat2->mH()), alpha.conj())
          : Tensor{},
      mat2_requires_grad
          // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
          ? maybe_multiply(mat1->mH().mm(grad_projected), alpha.conj())
          : Tensor{});
}

// Backward of sparse_mask: restricts `grad` to the pattern of `mask`, and
// densifies the result when the original `self` was strided.
Tensor sparse_mask_backward(
    const Tensor& grad,
    const Tensor& mask,
    const c10::Layout self_layout) {
  // NOTE: sparse_mask accumulates matches, so the backward step has to
  // accumulate as well.
  auto masked_grad =
      sparse_mask_like_grad(mask, grad, /*accumulate_matches=*/true);
  if (self_layout == at::kStrided) {
    return masked_grad.to_dense();
  }
  return masked_grad;
}

// Backward of sparse @ sparse matmul c = a @ b, for one operand at a time.
//
// Dense-tensor reference:
//   a_grad = c_grad @ b^H
//   b_grad = a^H @ c_grad
// For sparse operands the result is restricted to the operand's own pattern:
//   grad_order == 0 -> a_grad = sparse_matrix_mask(c_grad @ b^H, mask=a)
//   grad_order == 1 -> b_grad = sparse_matrix_mask(a^H @ c_grad, mask=b)
Tensor sparse_sparse_matmul_backward(
    const Tensor& grad,
    const Tensor& a,
    const Tensor& b,
    int64_t grad_order) {
  TORCH_CHECK(
      grad_order == 0 || grad_order == 1,
      ": grad_order not in [0, 1] at sparse_sparse_matmul_backward function");

  // NOTE: _sparse_sparse_matmul returns a coalesced gradient, hence there is
  // no need to accumulate matches.
  const bool wrt_a = grad_order == 0;
  const Tensor& pattern = wrt_a ? a : b;
  const Tensor full_grad = wrt_a
      ? _sparse_sparse_matmul(grad, b.conj().t())
      : _sparse_sparse_matmul(a.conj().t(), grad);
  return sparse_mask_like_grad(
      pattern, full_grad, /*accumulate_matches=*/false);
}

// Backward of renorm(self, p, dim, maxnorm). Every slice along `dim` whose
// p-norm (taken over all other dims) exceeds maxnorm was rescaled to
// maxnorm * x / norm. For those slices the gradient is
//   maxnorm / norm * (grad - (d norm/dx) * <x, grad> / norm),
// and elsewhere grad passes through unchanged. The norm is computed in the
// accumulate type, and 1e-7 keeps the reciprocal finite.
Tensor renorm_backward(
    const Tensor& grad,
    const Tensor& self,
    const Scalar& p,
    int64_t dim,
    const Scalar& maxnorm) {
  const auto ndim = self.dim();
  dim = c10::maybe_wrap_dim(dim, ndim);
  // Every dimension except `dim`.
  at::DimVector reduce_dims(ndim);
  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
  reduce_dims.erase(reduce_dims.begin() + dim);

  const auto acc_type =
      at::toAccumulateType(self.scalar_type(), self.device().type());
  auto norm = at::linalg_vector_norm(
      self, p, reduce_dims, /*keepdim=*/true, /*dtype=*/acc_type);

  // Per-slice <x, grad>: the gradient flowing into `norm`.
  const auto real_acc_type = c10::toRealValueType(acc_type);
  auto inner = self.conj() * grad;
  if (real_acc_type != acc_type) {
    // vector_norm output is real, so its incoming gradient must be real too
    inner = at::real(inner);
  }
  inner = inner.sum(reduce_dims, /*keepdim=*/true, /*dtype=*/real_acc_type);
  auto norm_grad = norm_backward(
      std::move(inner), self, p, norm, reduce_dims, /*keepdim=*/true);

  auto inv_norm = (norm + 1e-7).reciprocal();
  auto rescaled_grad = maxnorm * inv_norm * (grad - inv_norm * norm_grad);
  return at::where(norm > maxnorm, rescaled_grad.to(grad.scalar_type()), grad);
}

// Forward-mode derivative of renorm(self, p, dim, maxnorm). For slices whose
// norm (over all dims except `dim`) exceeds maxnorm:
//   d(out) = maxnorm / norm * (self_t - self_p / norm * d(norm)),
// and elsewhere the tangent is unchanged. The norm is taken in the CUDA
// accumulate type (e.g. float for half) and used to scale the result.
Tensor renorm_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    const Scalar& p,
    int64_t dim,
    const Scalar& maxnorm) {
  const auto self_sizes = self_p.sizes();
  dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(self_sizes.size()));

  // Every dimension except `dim`.
  at::DimVector reduce_dims(self_sizes.size());
  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
  reduce_dims.erase(reduce_dims.begin() + dim);

  // For cuda half, calculate norm in float precision then cast
  // normalization factor to half
  const auto dtype = self_p.scalar_type();
  const auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
  const double p_value = p.toDouble();
  Tensor norm = (acc_type != dtype)
      ? at::linalg_vector_norm(
            self_p, p_value, reduce_dims, /*keepdim=*/true, /*dtype=*/acc_type)
      : at::linalg_vector_norm(self_p, p_value, reduce_dims, /*keepdim=*/true);

  const double max_norm = maxnorm.toDouble();
  const auto inv_norm = (norm + 1e-7).reciprocal();
  const auto scale = inv_norm * max_norm;
  const auto norm_t =
      norm_jvp(self_p, self_t, p, norm, reduce_dims, /*keepdim=*/true);
  const auto rescaled_t = scale * (self_t - self_p * inv_norm * norm_t);

  return where(norm > max_norm, rescaled_t, self_t);
}

// Backward of Tensor.repeat(repeats).
//
// First, the leading dims that repeat() prepended (repeats.size() exceeds
// input_shape.size() by that many) are summed away one at a time. Then all
// remaining repeated dims are folded with a single reshape + sum (see the
// worked example below the loop).
//
// Args:
//   grad         gradient of the repeat() output
//   repeats      the `repeats` argument that was passed to repeat()
//   input_shape  sizes of the original input
Tensor repeat_backward(
    Tensor grad,
    c10::SymIntArrayRef repeats,
    c10::SymIntArrayRef input_shape) {
  // A zero repeat produces an empty output, so no input element contributed
  // and the gradient is all zeros.
  auto find_iter = std::find(repeats.cbegin(), repeats.cend(), 0);
  if (find_iter != repeats.cend()) {
    return at::zeros_symint(input_shape, grad.options());
  }
  const auto input_dims = input_shape.size();
  // Number of leading dims repeat() added beyond the input's rank. Summing
  // over dim 0 that many times removes them.
  auto num_unsqueezed = grad.dim() - input_dims;
  for (const auto i : c10::irange(num_unsqueezed)) {
    (void)i; // Suppress unused variable warning
    grad = grad.sum(0, false);
  }

  at::SymDimVector grad_size;
  at::DimVector sum_dims;
  for (const auto dim : c10::irange(input_dims)) {
    const auto& repeat = repeats[dim + num_unsqueezed];
    // Reshape gradient (repeat > 1)
    // Index:      [..., dim    , ...]    [..., dim   ,  dim+1        , ...]
    // Shape: From [..., dimsize, ...] to [..., repeat, dimsize/repeat, ...]
    // The gradient tensor at 'dim' is reshaped into 'repeat' copies of the
    // input tensor. The gradients are then summed over the repeated copies
    // along 'dim', which reduces the shape from 'repeat * dimsize/repeat' to
    // 'dimsize/repeat' ('input_dimsize'). Example:
    //        Size(3, 2)                                      Size(6, 2)
    //                                                      [[v1_0, v1_1],
    //                                                       [v1_2, v1_3],
    //        [[v0, v1],               repeat(2, 1)          [v1_4, v1_5],
    //         [v2, v3],              ------------->         [v2_0, v2_1],
    //         [v4, v5]]                                     [v2_2, v2_3],
    //                                                       [v2_4, v2_5]]
    //
    //    input grad (3, 2)      reshape (2, 3, 2)         output grad (6, 2)
    //                            [[[g1_0, g1_1],            [[g1_0, g1_1],
    //                              [g1_2, g1_3],             [g1_2, g1_3],
    // [[g1_0+g2_0, g1_1+g2_1],     [g1_4, g1_5]],            [g1_4, g1_5],
    //  [g1_2+g2_2, g1_3+g2_3],     [g2_0, g2_1],            [[g2_0, g2_1],
    //  [g1_4+g2_4, g1_5+g2_5]]     [g2_2, g2_3],             [g2_2, g2_3],
    //                              [g2_4, g2_5]]             [g2_4, g2_5]]]
    //
    // If the gradient tensor were instead reshaped to
    // [..., dimsize/repeat, repeat, ...] and summed over 'dim+1', the
    // gradient would not be correctly aligned with the input. Example:
    //  input grad (3, 2)        reshape (3, 2, 2)        output grad (6, 2)
    //                           [[[g1_0, g1_1],           [[g1_0, g1_1],
    //                             [g1_2, g1_3]],           [g1_2, g1_3],
    // [[g1_0+g1_2, g1_1+g1_3],   [[g1_4, g1_5],            [g1_4, g1_5],
    //  [g1_4+g2_0, g1_5+g2_1],    [g2_0, g2_1]],           [g2_0, g2_1],
    //  [g2_2+g2_4, g2_3+g2_5]]   [[g2_2, g2_3],            [g2_2, g2_3],
    //                             [g2_4, g2_5]]]           [g2_4, g2_5]]
    if (repeat != 1) {
      grad_size.push_back(repeat);
      sum_dims.push_back(static_cast<int64_t>(grad_size.size() - 1));
    }
    // When repeat == 1 there is no need to split the dim into
    // (repeat, input_shape[dim]); only the input size is appended.
    grad_size.push_back(input_shape[dim]);
  }
  // One-time Reshape & Sum
  // Reshape gradient to grad_size:
  //   1. If repeat equals 1, append the input size at that dimension.
  //   2. If repeat is larger than 1, append both repeat and the input size at
  //   that dimension.
  // Sum over all "repeat" dimensions from sum_dims:
  // Example:
  // Input Size         (2,    3,    4,    5)
  // repeat             [4,    1,    9,    3]
  // output/grad Size   (8,    3,    36,   15)
  // grad_size          [4, 2,    3, 9, 4, 3, 5]
  // sum_dims           [0,          3,    5]

  // When every original dimension is repeated exactly once, sum_dims is
  // empty, and sum() with an empty dim list would reduce the whole grad
  // tensor to a scalar instead of keeping the original dimensions. Hence the
  // guard.
  if (!sum_dims.empty()) {
    grad = grad.reshape_symint(grad_size);
    grad = grad.sum(sum_dims);
  }
  return grad;
}

// Backward of _fused_dropout. `p1m` is the keep probability, p1m == 1 - p.
// Kept elements (per `mask`) receive grad / p1m; dropped elements receive 0
// on the autograd-friendly path.
Tensor _fused_dropout_backward(
    const Tensor& grad,
    const Tensor& mask,
    double p1m) {
  const double scale = 1. / p1m;
  if (!grad.requires_grad()) {
    return at::_masked_scale(grad, mask, scale);
  }
  // Use autograd-friendly backward if double backward is required
  return grad * (mask.type_as(grad) * scale);
}

// Backward of native_dropout written with differentiable ops only.
// `scale` == 1 / (1 - prob); `mask` is cast to grad's dtype before scaling.
Tensor infinitely_differentiable_native_dropout_backward(
    const Tensor& grad,
    const Tensor& mask,
    double scale) {
  const auto scaled_mask = mask.type_as(grad) * scale;
  return grad * scaled_mask;
}

// Double backward of native_dropout: the incoming `ggI` is masked and scaled
// exactly like the first-order gradient, with both operands cast to grad's
// dtype.
Tensor native_dropout_double_backward(
    const Tensor& ggI,
    const Tensor& grad,
    const Tensor& mask,
    double scale) {
  const auto ggI_as_grad = ggI.type_as(grad);
  const auto scaled_mask = mask.type_as(grad) * scale;
  return ggI_as_grad * scaled_mask;
}

// Backward for a reduction whose result `value` equals one or more entries of
// `input`: `grad` is split evenly among all matching positions. A NaN `value`
// is considered to match the NaN entries of `input`.
Tensor evenly_distribute_backward(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& value) {
  const bool has_subclass = areAnyTensorSubclassLike({grad, input, value});
  if (!has_subclass && !input.is_cuda()) {
    // Host path: read whether `value` is NaN and fill the matches directly.
    auto mask = value.isnan().item<bool>() ? input.isnan() : input == value;
    return grad.new_zeros(input.sizes(), input.options())
        .masked_fill_(mask, grad / mask.sum());
  }
  // Device / subclass path: no host sync; in-place only for plain tensors.
  const auto input_nan = input.isnan();
  const auto value_nan = value.isnan();
  const auto& both_nan = has_subclass ? input_nan.logical_and(value_nan)
                                      : input_nan.logical_and_(value_nan);
  const auto mask = (input == value).logical_or_(both_nan);
  return mask * (grad / mask.sum());
}

// Forward-mode derivative (JVP) paired with evenly_distribute_backward: the
// tangent of `value` is the mean of `fw_grad` over the positions where
// `input` equals `value`. NaN compares unequal to itself, so — matching
// evenly_distribute_backward — a NaN `value` is treated as read from the NaN
// entries of `input`. With `==` alone the mask would be empty for a NaN
// `value` and the result would be 0 * (fw_grad / 0) = NaN.
Tensor evenly_read_jvp(
    const Tensor& fw_grad,
    const Tensor& input,
    const Tensor& value) {
  // Out-of-place logical ops keep this safe for tensor subclasses.
  const auto both_nan = input.isnan().logical_and(value.isnan());
  const auto mask = (input == value).logical_or(both_nan);
  const auto count = mask.sum();
  const auto grad_output = fw_grad / count;
  return at::sum(mask * grad_output);
}

// Backward of var. With N the number of elements each variance is taken
// over, d var / d self = 2 / (N - correction) * (self - mean), where the
// correction defaults to 1 when `correction_opt` is empty.
Tensor var_backward(
    Tensor grad,
    const Tensor& self,
    at::OptionalIntArrayRef dim_opt,
    const std::optional<at::Scalar>& correction_opt,
    bool keepdim) {
  const auto correction = correction_opt.value_or(1).toSymFloat();
  // Full reduction: a 0-dim `self` or no `dim` means a single variance over
  // every element.
  if (self.dim() == 0 || !dim_opt.has_value()) {
    const auto dof = c10::SymFloat(self.sym_numel()) - correction;
    if (dof <= 0) {
      // when n == correction, 2 / (n - correction) is infinity
      // when self == self.mean(), we return NaN because infinity * 0 = NaN
      // otherwise, we return infinity because infinity * c = infinity, for all
      // c > 0
      return grad *
          at::where(
                 self == self.mean(),
                 std::numeric_limits<double>::quiet_NaN(),
                 std::numeric_limits<double>::infinity());
    } else {
      return (c10::SymFloat(2.0) / dof) * grad * (self - self.mean());
    }
  }
  auto dim = dim_opt.value();
  // Re-insert the reduced dimensions so `grad` broadcasts against `self`.
  // NOTE(review): skipped for 1-d `self`, presumably because the reduced
  // `grad` is then 0-dim and already broadcasts.
  if (!keepdim && self.dim() > 1) {
    grad = unsqueeze_multiple(grad, dim, self.sym_sizes().size());
  }
  // Presumably the number of elements reduced over (product of the sizes at
  // `dim`). NOTE(review): unlike the branch above, rnumel <= correction is not
  // special-cased here (rnumel == correction divides by zero).
  const c10::SymFloat rnumel(_safe_size(self.sym_sizes(), dim));
  return (c10::SymFloat(2.0) / (rnumel - correction)) * grad *
      (self - self.mean(dim, /*keepdim=*/true));
}

// Backward of std. Since std = sqrt(var), the incoming gradient is divided by
// 2 * std and forwarded to var_backward. Entries where the forward result is 0
// get 0 instead of the inf/NaN the division would produce.
Tensor std_backward(
    const Tensor& result,
    const Tensor& grad,
    const Tensor& self,
    at::OptionalIntArrayRef dim,
    const std::optional<c10::Scalar>& correction_opt,
    bool keepdim) {
  const auto zero_std = result == 0;
  auto grad_var = grad / (result * 2);
  grad_var.masked_fill_(zero_std, 0);
  return var_backward(std::move(grad_var), self, dim, correction_opt, keepdim);
}

// Backward of var_mean: sums the contribution of the variance gradient `gvar`
// and of the mean gradient `gmean`. An undefined gradient contributes
// nothing; the result is undefined when both are undefined.
Tensor var_mean_backward(
    const Tensor& gvar,
    const Tensor& gmean,
    const Tensor& self,
    at::OptionalIntArrayRef dim_opt,
    const std::optional<c10::Scalar>& correction_opt,
    bool keepdim) {
  Tensor from_var;
  if (gvar.defined()) {
    from_var = var_backward(gvar, self, dim_opt, correction_opt, keepdim);
  }
  if (!gmean.defined()) {
    return from_var;
  }
  auto from_mean = mean_backward(
      gmean,
      self.sym_sizes(),
      dim_opt.value_or(IntArrayRef({})),
      self.sym_numel(),
      keepdim);
  if (!from_var.defined()) {
    return from_mean;
  }
  return from_var + from_mean;
}

// Backward of std_mean: sums the contribution of the standard-deviation
// gradient `gstd` (through std_backward, using the forward result `std`) and
// of the mean gradient `gmean`. Undefined gradients are skipped; the result is
// undefined when both are undefined.
Tensor std_mean_backward(
    const Tensor& gstd,
    const Tensor& gmean,
    const Tensor& self,
    const Tensor& std,
    at::OptionalIntArrayRef dim_opt,
    const std::optional<c10::Scalar>& correction_opt,
    bool keepdim) {
  Tensor from_std;
  if (gstd.defined()) {
    from_std = std_backward(std, gstd, self, dim_opt, correction_opt, keepdim);
  }
  if (!gmean.defined()) {
    return from_std;
  }
  auto from_mean = mean_backward(
      gmean,
      self.sym_sizes(),
      dim_opt.value_or(IntArrayRef({})),
      self.sym_numel(),
      keepdim);
  if (!from_std.defined()) {
    return from_mean;
  }
  return from_std + from_mean;
}

// Forward-mode derivative (JVP) of the Cholesky factorization A = L L^H.
// `dA` is the tangent of A and `L` the computed factor (stored as the
// upper-triangular L^H when `upper` is true). Returns the tangent of the
// factor in the same triangular layout as `L`.
Tensor cholesky_jvp(const Tensor& dA, const Tensor& L, bool upper) {
  // TF32 is disabled for the solves/products below (presumably for accuracy).
  at::NoTF32Guard disable_tf32;
  // Let A = LL^H
  // dA = dLL^H + L(dL)^H
  // L^{-1}dA(L^{-H}) = L^{-1}dL + (L^{-1}dL)^H
  //               = sym(L^{-1}dL)
  // where sym(X) = X + X^H
  // A short computation gives that the inverse of sym is given by
  // \pi(X) = X.tril() - 0.5*diag(X)
  // so
  // dL = L\pi(L^{-1}dA(L^{-H}))

  // Precondition: dA is symmetric/Hermitian
  // Work with the lower-triangular factor throughout.
  auto L_ = upper ? L.mH() : L;
  // dL <- L^{-1} dA (left triangular solve with the lower factor).
  auto dL = at::linalg_solve_triangular(L_, dA, /*upper=*/false, /*left=*/true);
  // dL <- (L^{-1} dA) L^{-H} (right triangular solve with L^H).
  dL = at::linalg_solve_triangular(L_.mH(), dL, /*upper=*/true, /*left=*/false);
  // Apply \pi: keep the lower triangle and halve the diagonal.
  dL = dL.tril() - dL.diagonal(0, -2, -1).mul(0.5).diag_embed();
  // dL <- L \pi(L^{-1} dA L^{-H})
  dL = L_.matmul(dL);
  // Convert back to the caller's triangular layout.
  return upper ? dL.mH() : std::move(dL);
}

// Backward of the Cholesky factorization A = L L^H: maps the gradient `gL`
// with respect to the factor (stored as L^H when `upper` is true) to the
// gradient with respect to A. The result is Hermitian, not triangular.
Tensor cholesky_backward(const Tensor& gL, bool upper, const Tensor& L) {
  // TF32 is disabled for the solves/products below (presumably for accuracy).
  at::NoTF32Guard disable_tf32;
  // From cholesky_jvp we have that
  // dL = L\pi(L^{-1}dA(L^-H))
  //
  // Let gL be the projection into the lower-triangular gradient wrt L. Taking
  // adjoints we have gA = L^{-H}\pi^*((L^HgL).tril())L^{-1} where \pi^*(X) =
  // 0.5 * (X + X^H - diag(X)) The only non-standard point of this derivation is
  // noting that the adjoint to multiplying on the left by a lower triangular
  // matrix L is multiplying by L^H and then projecting back to the lower
  // triangular matrices (hence the .tril() projection) Note that the gradient
  // is symmetric and not triangular.
  // Work with the lower-triangular factor and its gradient.
  auto L_ = upper ? L.mH() : L;
  auto gL_ = upper ? gL.mH() : gL;

  // Nb. We don't need to compute gL_ = gL.tril() as
  // tril(L^H gL) = tril(L^H (triu(gL, 1) + tril(gL)))
  //              = tril(L^H tril(gL)) + tril(L^H triu(gL, 1))
  //              = tril(L^H tril(gL))
  // since tril(L^H triu(gL, 1)) = 0, as L^H triu(gL, 1) is upper triangular
  auto gA = L_.mH().matmul(gL_).tril();
  // Equivalent to 0.5 * (gA + gA^H - diag(gA))
  gA = 0.5 * (gA + gA.tril(-1).mH());
  // gA <- L^{-H} gA (left solve with the upper-triangular L^H).
  gA = at::linalg_solve_triangular(L_.mH(), gA, /*upper=*/true, /*left=*/true);
  // gA <- gA L^{-1} (right solve with the lower-triangular L).
  gA = at::linalg_solve_triangular(L_, gA, /*upper=*/false, /*left=*/false);
  return gA;
}

// Backward of cholesky_inverse with respect to the factor `L`, given the
// forward result `inverse`. With S = grad + grad^H, the gradient is
// -inverse S inverse multiplied by `L` on the left when `upper` is true and
// on the right otherwise. Returns an undefined tensor when `grad` is
// undefined.
Tensor cholesky_inverse_backward(
    const Tensor& grad,
    const Tensor& L,
    bool upper,
    const Tensor& inverse) {
  at::NoTF32Guard disable_tf32;
  if (!grad.defined()) {
    return Tensor();
  }
  const Tensor hermitian_grad = grad + grad.mH();
  const Tensor common_term =
      at::matmul(inverse, at::matmul(hermitian_grad, inverse));
  if (!upper) {
    return -at::matmul(common_term, L);
  }
  return -at::matmul(L, common_term);
}

// Forward-mode derivative (JVP) of cholesky_inverse.
//
// If X = (L L^H)^{-1} with L lower-triangular with a real positive diagonal,
// then dX = K^H + K, where
// K =  L^{-H} dL^{-1} [dL^{-1} = -L^{-1} dL L^{-1}]
//   = -L^{-H} L^{-1} dL L^{-1} [L^{-H} L^{-1} = X]
//   = -X dL L^{-1} [L^{-1} = L^H (L^{-H} L^{-1}) = L^H X]
//   = -X dL L^{H} X.
// If X = (U^H U)^{-1} with U upper-triangular with a real positive diagonal,
// then K becomes
// K = -X dU^H U X
//
// X does not commute with L^H (resp. U) in general, so the factor must sit
// between dL and the trailing X. This K is the adjoint of the gradient
// -X (G + G^H) X L computed by cholesky_inverse_backward.
Tensor cholesky_inverse_jvp(
    const Tensor& F,
    const Tensor& dF,
    const Tensor& X,
    bool upper) {
  at::NoTF32Guard disable_tf32;
  // Express both cases in the upper convention: CF = U (= L^H when lower) and
  // dCF = dU^H (= dL when lower), so K = -X dCF CF X.
  const auto CF = upper ? F : F.mH();
  const auto dCF = upper ? dF.mH() : dF;
  const auto partial_dX = -X.matmul(dCF).matmul(CF).matmul(X);
  return partial_dX + partial_dX.mH();
}

// The formula for forward AD is adapted from
//
// Golub, Gene H., and Victor Pereyra. "The Differentiation of Pseudo-Inverses
// and Nonlinear Least Squares Problems Whose Variables Separate." SIAM Journal
// on Numerical Analysis 10(2). (1973). 413-432. doi: 10.1137/0710036
//
// We present a short derivation below:
// Let Ap := pinv(A), then Ap is the unique matrix such that
//
// Ap A Ap = Ap [1]
// A Ap A = A   [2]
//
// By differentiating [1] we get:
//
// dAp = dAp A Ap + Ap dA Ap + Ap A dAp [3]
//
// In the rhs of [3] the products involving dAp could be expressed as products
// of Ap^i, A^j, dA^k with i, j, k in {1, H}, where X^H = X.mH(). To prove that,
// note (A Ap)^H = A Ap and (Ap A)^H = Ap A, which could be shown by taking the
// product between the SVD decompositions of A and Ap. Consider the
// conjugate-transposed [2]: (A Ap A)^H = A^H (A Ap) = A^H. By differentiating
// it we get: dA^H A Ap + A^H dA Ap + A^H A dAp = dA^H. By multiplying from the
// left by Ap^H and using Ap^H A^H = (A Ap)^H = A Ap: Ap^H dA^H A Ap + A Ap dA
// Ap + A Ap A dAp = Ap^H dA^H. By multiplying from the left by Ap and by
// applying [1] and [2] repeatedly until impossible we get: Ap Ap^H dA^H A Ap +
// Ap dA Ap + Ap A dAp = Ap Ap^H dA^H. By rearranging the terms:
//
// Ap A dAp = -Ap dA Ap + Ap Ap^H dA^H (I - A Ap) [4],
// which is one of the summands in [3].
//
// Similarly, by differentiating the transpose-conjugated [2] written
// differently, i.e. (A Ap A)^H = Ap A A^H = A^H we will get an expression for
// dAp A Ap, which is
//
// dAp A Ap = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap [5].
//
// By plugging in [4] and [5] into [3] we get the forward AD formula for pinv:
//
// dAp = -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap + Ap Ap^H dA^H (I - A Ap).
Tensor pinv_jvp(const Tensor& A, const Tensor& pinvA, const Tensor& dA) {
  at::NoTF32Guard disable_tf32;
  auto m = A.size(-2);
  auto n = A.size(-1);
  auto dAh = dA.mH();
  auto pinvAh = pinvA.mH();
  // optimization to produce matrices of the smallest dimension
  if (m <= n) {
    // K = Ap^H dA^H (so K^H = dA Ap). The first summand expands to
    // -Ap dA Ap + Ap Ap^H dA^H (I - A Ap); the second to
    // (I - Ap A) dA^H Ap^H Ap.
    auto K = pinvAh.matmul(dAh);
    return pinvA.matmul(K - K.mH() - K.matmul(A.matmul(pinvA))) +
        (dAh - pinvA.matmul(A.matmul(dAh))).matmul(pinvAh.matmul(pinvA));
  } else {
    // K = Ap dA, Kh = dA^H Ap^H. The first summand expands to
    // -Ap dA Ap + (I - Ap A) dA^H Ap^H Ap; the second to
    // Ap Ap^H dA^H (I - A Ap).
    auto K = pinvA.matmul(dA);
    auto Kh = K.mH();
    return (Kh - K - pinvA.matmul(A).matmul(Kh)).matmul(pinvA) +
        (pinvA.matmul(pinvAh)).matmul(dAh - (dAh.matmul(A)).matmul(pinvA));
  }
}

// Backward of pinv: the adjoint of the formula used by pinv_jvp. With
// Ap = pinvA and G = grad,
//   gA = -Ap^H G Ap^H + (I - A Ap) G^H Ap Ap^H + Ap^H Ap G^H (I - Ap A).
// The two branches differ only in how the products are associated.
Tensor pinv_backward(const Tensor& grad, const Tensor& pinvA, const Tensor& A) {
  at::NoTF32Guard disable_tf32;
  auto m = A.sym_size(-2);
  auto n = A.sym_size(-1);
  auto pinvAh = pinvA.mH();
  auto gradh = grad.mH();
  // optimization to produce matrices of the smallest dimension
  if (m <= n) {
    // K = G^H Ap, KpinvAh = G^H Ap Ap^H.
    auto K = gradh.matmul(pinvA);
    auto KpinvAh = K.matmul(pinvAh);
    return -(pinvA.matmul(K)).mH() + KpinvAh -
        (A.matmul(pinvA)).matmul(KpinvAh) +
        (pinvAh.matmul(pinvA)).matmul(gradh - K.matmul(A));
  } else {
    // K = Ap G^H, pinvAhK = Ap^H Ap G^H.
    auto K = pinvA.matmul(gradh);
    auto pinvAhK = pinvAh.matmul(K);
    return -(K.matmul(pinvA)).mH() +
        (gradh - A.matmul(K)).matmul(pinvA).matmul(pinvAh) + pinvAhK -
        pinvAhK.matmul(pinvA).matmul(A);
  }
}

// Backward of split_with_sizes: concatenates the output gradients along
// `dim`. Undefined entries of `grads` (outputs that received no gradient) are
// materialized as zero tensors of the corresponding split size, because
// at::cat cannot consume undefined tensors.
Tensor split_with_sizes_backward(
    const std::vector<torch::autograd::Variable>& grads,
    c10::SymIntArrayRef split_sizes,
    int64_t dim,
    c10::SymIntArrayRef sizes,
    const at::TensorOptions& options) {
  dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));

  std::vector<Tensor> pieces;
  pieces.reserve(grads.size());
  for (const auto idx : c10::irange(grads.size())) {
    const auto& piece = grads[idx];
    if (piece.defined()) {
      pieces.push_back(piece);
      continue;
    }
    auto zero_shape = sizes.vec();
    zero_shape[dim] = split_sizes[idx];
    pieces.push_back(at::zeros_symint(zero_shape, options));
  }

  return at::cat(pieces, dim);
}

// Backward of split_with_sizes for nested tensors: concatenates the output
// gradients along `dim`, materializing undefined ones as zero nested tensors.
// `nt_sizes` is the input's nested size tensor. The code reads it as an int64
// matrix with one row per component (size(0)) and one column per non-batch
// dimension (size(1)). NOTE(review): the row-major indexing into data_ptr
// assumes the clone is contiguous — confirm.
Tensor _nested_split_with_sizes_backward(
    const std::vector<torch::autograd::Variable>& grads,
    c10::SymIntArrayRef split_sizes,
    int64_t dim,
    const Tensor& nt_sizes,
    const at::TensorOptions& options) {
  // add 1 to account for batch dim
  dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(nt_sizes.size(1)) + 1);
  // it's possible some of the grads are not defined (represents tensors of all
  // 0s). Since at::cat can't handle those, let's define them
  std::vector<Tensor> grads_all_defined;
  for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
    if (grads[i].defined()) {
      grads_all_defined.push_back(static_cast<Tensor>(grads[i]));
    } else {
      // Concrete split length (guards on the SymInt).
      const auto& length = split_sizes[i].guard_int(__FILE__, __LINE__);
      // Sizes of the zero gradient: same as the input's, except every
      // component gets `length` along `dim`.
      auto nt_split_size = nt_sizes.clone();
      auto nt_split_size_ptr = nt_split_size.data_ptr<int64_t>();
      for (int64_t j : c10::irange(static_cast<int64_t>(nt_sizes.size(0)))) {
        // subtract 1 to account for batch dim
        nt_split_size_ptr
            [j * static_cast<int64_t>(nt_sizes.size(1)) + (dim - 1)] = length;
      }
      // Flat zero buffer holding all components, wrapped as a nested tensor.
      Tensor zeros_buffer = at::zeros(
          {at::native::get_numel_from_nested_size_tensor(nt_split_size)},
          options);
      auto nt_split_grad = at::native::wrap_buffer(zeros_buffer, nt_split_size);
      grads_all_defined.push_back(nt_split_grad);
    }
  }

  auto ret = at::cat(grads_all_defined, dim);
  return ret;
}

// Backward of split with a uniform `split_size`. Every chunk spans
// `split_size` elements along `dim` except the last, which holds whatever
// remains of the dimension. The work is delegated to
// split_with_sizes_backward with the explicit per-chunk sizes.
Tensor split_backward(
    const std::vector<torch::autograd::Variable>& grads,
    const c10::SymInt& split_size,
    int64_t dim,
    c10::SymIntArrayRef sym_sizes,
    const at::TensorOptions& options) {
  dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(sym_sizes.size()));
  const auto chunk_count = grads.size();
  std::vector<c10::SymInt> chunk_sizes(chunk_count, split_size);
  // The last chunk is short by the amount chunk_count * split_size
  // overshoots the dimension length.
  const auto overshoot = split_size * chunk_count - sym_sizes[dim];
  chunk_sizes.back() = split_size - overshoot;
  return split_with_sizes_backward(grads, chunk_sizes, dim, sym_sizes, options);
}

// Double backward of max pooling. The last `dim` dimensions of `grad` are
// flattened, the entries at the saved argmax `indices` are gathered, and the
// result is reshaped to the indices' shape. Empty `indices` short-circuit to
// an empty result.
Tensor max_pool_double_backward(
    const Tensor& grad,
    const Tensor& indices,
    int dim) {
  AT_ASSERT(indices.dim() >= dim);
  if (indices.sym_numel() == 0) {
    return at::empty_like(indices, grad.options());
  }
  // Keep the leading dimensions and collapse the trailing `dim` ones.
  auto flat_shape = indices.sym_sizes().slice(0, indices.dim() - dim).vec();
  flat_shape.emplace_back(-1);
  auto flat_indices = indices.view_symint(flat_shape);
  const auto memory_format = indices.suggest_memory_format();
  return grad.contiguous(memory_format)
      .view_symint(flat_shape)
      .gather(-1, flat_indices)
      .view_symint(indices.sym_sizes());
}

// Always throws: it reports that max_pool2d with `return_indices=False` has
// no double backward. Marked as MPS-only by the original note.
Tensor error_for_max_pool2d_double_backward() { // This is mps-only.
  TORCH_CHECK(
      false,
      "max_pool2d with `return_indices=False` is not infinitely differentiable.",
      " If you want to calculate higher order derivatives, e.g. second order,",
      " set `return_indices=True`.");
  // Unreachable; only satisfies the declared return type.
  return at::Tensor();
}

// Double backward of glu with respect to `input`. `input` is split along
// `dim` into halves a (first) and b (second), and glu = a * s with
// s = sigmoid(b). The first-order gradients are gO * s (w.r.t. a) and
// gO * a * s' (w.r.t. b), with s' = s (1 - s). Differentiating their
// contraction with `grad` = (grad_a, grad_b) once more gives
//   gI_a = grad_b * gO * s'
//   gI_b = grad_b * a * gO * s'' + grad_a * gO * s',  s'' = s'(1 - s) - s s'.
Tensor glu_double_backward(
    const Tensor& grad,
    const Tensor& grad_output,
    const Tensor& input,
    int64_t dim) {
  auto& gO = grad_output;
  auto input_size = input.size(dim) / 2;
  auto first_half = input.narrow(dim, 0, input_size);
  auto second_half = input.narrow(dim, input_size, input_size);
  auto sig_second_half = second_half.sigmoid();
  auto one_sub_sig_second_half = 1 - sig_second_half;
  // s' = s * (1 - s)
  auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half;

  auto ggI_first_half = grad.narrow(dim, 0, input_size);
  auto ggI_second_half = grad.narrow(dim, input_size, input_size);
  auto ggI_second_half_times_first_half = ggI_second_half * first_half;

  auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig;
  // s'' = s' * (1 - s) - s * s'
  auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half -
      sig_second_half * sig_one_sub_sig;
  auto gI_second_half =
      ggI_second_half_times_first_half * gO * second_order_sh +
      ggI_first_half * gO * sig_one_sub_sig;
  return at::cat({std::move(gI_first_half), std::move(gI_second_half)}, dim);
}

// Gradient of glu_backward with respect to grad_output, contracted with
// `grad`. A backward is linear in its incoming gradient, so glu_backward is
// evaluated at an all-ones grad_output (half of `input` along `dim`) and
// multiplied by `grad`. The two halves along `dim` are then summed back onto
// grad_output's shape.
Tensor glu_double_backward_grad_output(
    const Tensor& grad,
    const Tensor& input,
    int64_t dim) {
  // Bounds-checked wrap: the manual `dim += input.dim()` accepted any value
  // and could index `sizes` out of range.
  dim = at::maybe_wrap_dim(dim, input.dim());
  auto sizes = input.sizes().vec();
  sizes[dim] /= 2;
  const auto half = sizes[dim];
  auto tmp = grad * glu_backward(at::ones(sizes, input.options()), input, dim);
  return tmp.narrow(dim, 0, half) + tmp.narrow(dim, half, half);
}

// Backward of SiLU (x * sigmoid(x)) built from differentiable ops:
// d/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
Tensor infinitely_differentiable_silu_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  const Tensor sig = input.sigmoid();
  const auto factor = 1.0 + input * (1.0 - sig);
  return grad_output * sig * factor;
}

// Backward of Mish (x * tanh(softplus(x))) built from differentiable ops:
// d/dx = tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2), with
// sp = log1p(exp(x)).
Tensor infinitely_differentiable_mish_backward(
    const Tensor& grad_output,
    const Tensor& input) {
  const Tensor sig = input.sigmoid();
  const Tensor sp = input.exp().log1p();
  const Tensor t = sp.tanh();
  const auto sech_sq = 1.0 - t * t;
  const auto derivative = t + input * sig * sech_sq;
  return grad_output * derivative;
}

// Backward of logit built from differentiable ops: grad / (x * (1 - x)).
// With `eps`, entries outside [eps, 1 - eps] get 0. Without it, entries
// outside [0, 1] get NaN.
Tensor infinitely_differentiable_logit_backward(
    const Tensor& grad,
    const Tensor& self,
    std::optional<double> eps) {
  const auto raw = grad / (self * (1.0 - self));
  if (!eps.has_value()) {
    const auto in_domain = at::logical_and(self >= 0.0, self <= 1.0);
    return at::where(
        in_domain,
        raw,
        at::empty({}, self.options())
            .fill_(std::numeric_limits<double>::quiet_NaN()));
  }
  const double lower = *eps;
  const double upper = 1.0 - lower;
  const auto in_range = at::logical_and(self >= lower, self <= upper);
  return at::where(in_range, raw, at::zeros({}, self.options()));
}

// Backward of binary_cross_entropy with respect to `target`:
// d/dy of -w * (y * log(x) + (1 - y) * log(1 - x)) = -w * logit(x), scaled
// by `grad` and divided by numel for mean reduction. In-place ops are used
// only on plain tensors (see NOTE: [How to write vmap-compatible backward
// formulas]).
Tensor binary_cross_entropy_target_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  auto out = at::logit(self).neg_();

  if (areAnyTensorSubclassLike({grad})) {
    out = out * grad;
  } else {
    out.mul_(grad);
  }

  if (isDefined(weight)) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const Tensor& w = weight.value();
    if (isTensorSubclassLike(w)) {
      out = out * w;
    } else {
      out.mul_(w);
    }
  }

  if (reduction == at::Reduction::Mean) {
    out.div_(target.sym_numel());
  }
  return out;
}

// Mixed second derivative of binary_cross_entropy (input, then target):
// -w * grad * grad_output / (self * (1 - self)). The denominator is clamped
// to at least 1e-12 outside of autograd tracking.
Tensor binary_cross_entropy_double_backward_target(
    const Tensor& grad,
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  auto out = -grad * grad_output;

  if (isDefined(weight)) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const Tensor& w = weight.value();
    out = isTensorSubclassLike(w) ? out.mul(w) : out.mul_(w);
  }

  auto denom = 1 - self;
  if (isTensorSubclassLike(self)) {
    denom = denom.mul(self);
  } else {
    denom.mul_(self);
  }
  {
    at::NoGradGuard no_grad;
    // Default eps in binary_cross_entropy for ALL dtypes
    // TODO: probably change this to a dtype-dependent value
    const double eps = 1e-12;
    denom.clamp_min_(eps);
  }

  out = isTensorSubclassLike(denom) ? out.div(denom) : out.div_(denom);

  if (reduction == at::Reduction::Mean) {
    out.div_(target.sym_numel());
  }
  return out;
}

// Backward of binary_cross_entropy_with_logits with respect to `input`:
//   -w * [pos * y * (1 - sigmoid(x)) - (1 - y) * sigmoid(x)] * grad
//   = w * [(pos * y + 1 - y) * sigmoid(x) - pos * y] * grad
// divided by numel for mean reduction. In-place ops are used only when the
// operands are plain tensors and grad mode is disabled.
Tensor binary_cross_entropy_with_logits_backward(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    const std::optional<Tensor>& pos_weight,
    int64_t reduction) {
  // A zero incoming gradient yields a zero result.
  if (grad._is_zerotensor()) {
    return at::_efficientzerotensor(input.sizes(), input.options());
  }

  const bool grad_mode = at::GradMode::is_enabled();
  const bool inputs_out_of_place =
      grad_mode || at::areAnyTensorSubclassLike({input, target});

  Tensor result;
  if (!isDefined(pos_weight)) {
    result = inputs_out_of_place ? input.sigmoid().sub(target)
                                 : input.sigmoid().sub_(target);
  } else {
    // pos_weight might need to be broadcasted, thus mul(target) is not inplace.
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const auto pos_target = pos_weight->mul(target);
    if (inputs_out_of_place) {
      result = pos_target.add(1)
                   .sub(target)
                   .mul(input.sigmoid())
                   .sub(pos_target);
    } else {
      result = pos_target.add(1)
                   .sub_(target)
                   .mul_(input.sigmoid())
                   .sub_(pos_target);
    }
  }

  if (grad_mode || at::isTensorSubclassLike(grad)) {
    result = result.mul(grad);
  } else {
    result.mul_(grad);
  }

  if (isDefined(weight)) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const Tensor& w = *weight;
    if (grad_mode || at::isTensorSubclassLike(w)) {
      result = result.mul(w);
    } else {
      result.mul_(w);
    }
  }

  if (reduction == at::Reduction::Mean) {
    result.div_(input.sym_numel());
  }
  return result;
}

// Backward of binary_cross_entropy_with_logits with respect to `target`:
//   d/dy = w * [log(sigmoid(-x)) - pos * log(sigmoid(x))] * grad_output,
// which reduces to -x * w * grad_output when there is no pos_weight. The
// result is divided by numel for mean reduction.
Tensor binary_cross_entropy_with_logits_target_backward(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    const std::optional<Tensor>& pos_weight,
    int64_t reduction) {
  // A zero incoming gradient yields a zero result.
  if (grad_output._is_zerotensor()) {
    return at::_efficientzerotensor(target.sizes(), target.options());
  }

  Tensor out;
  if (!isDefined(pos_weight)) {
    out = -self * grad_output;
  } else {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const Tensor& pw = *pos_weight;
    if (areAnyTensorSubclassLike({pw, grad_output})) {
      out = at::log_sigmoid(-self)
                .sub(at::log_sigmoid(self).mul(pw))
                .mul(grad_output);
    } else {
      out = at::log_sigmoid(-self)
                .sub_(at::log_sigmoid(self).mul_(pw))
                .mul_(grad_output);
    }
  }

  if (isDefined(weight)) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const Tensor& w = *weight;
    if (at::isTensorSubclassLike(w)) {
      out = out.mul(w);
    } else {
      out.mul_(w);
    }
  }

  if (reduction == at::Reduction::Mean) {
    out.div_(target.sym_numel());
  }
  return out;
}

// Second derivative of log(sigmoid(x)) is (sigmoid(x) - 1) * sigmoid(x);
// the result is scaled by the incoming `grad`.
Tensor log_sigmoid_double_backward(const Tensor& grad, const Tensor& input) {
  const auto s = input.sigmoid();
  const auto partial = grad * (s - 1);
  return partial * s;
}

// Double backward of softmax along `dim`, given the forward `output`:
//   gO * grad - <output, gO> * grad - gO * <output, grad>,
// where <., .> is a keepdim sum over `dim`.
Tensor softmax_double_backward(
    const Tensor& grad,
    const Tensor& grad_output,
    int dim,
    const Tensor& output) {
  const auto out_dot_gO = (output * grad_output).sum(dim, true);
  const auto out_dot_grad = (output * grad).sum(dim, true);
  return grad_output * grad - out_dot_gO * grad - grad_output * out_dot_grad;
}

// NOTE: [How to write vmap-compatible backward formulas]
//
// See NOTE: [vmap-incompatible in-place operations] for what it means for an
// in-place operation to be incompatible with vmap.
//
// If an in-place operation used in a backward formula is vmap-incompatible,
// then as developers we have the following options:
//
// - If the in-place operation directly followed the creation of a tensor with
//   a factory function like at::zeros(...), we should replace the factory with
//   a corresponding grad.new_zeros(...) call. The grad.new_zeros(...) call
//   propagates the batch dims to the resulting tensor.
//   For example:
//     Before: at::zeros(input.sizes(), grad.options()).copy_(grad)
//     After:  grad.new_zeros(input.sizes()).copy_(grad)
//
// - If the in-place operation followed some sequence of operations, and we
//   want to be able to vmap over the backward formula as-is (this is
//   usually the case for simple (<15loc) backward formulas), then use
//   areAnyTensorSubclassLike to guard the operation. For example:
//             c = a * b
//     Before: c.mul_(grad)
//     After:  c = !areAnyTensorSubclassLike({c, grad}) ? c.mul_(grad)
//                                                       : c * grad
//
// - If we don't want to vmap directly over the backward formula (e.g., if the
//   backward formula is too complicated or has a lot of vmap-incompatible
//   operations), then register the backward formula as an operator and
//   eventually write a batching rule for it.

// Second derivative of binary_cross_entropy with respect to input x (target
// y): (x^2 - 2 x y + y) / (x^2 (1 - x)^2), with eps = 1e-12 added to x and
// 1 - x in the denominator. The result is scaled by grad * grad_output (and
// the optional weight) and divided by numel for mean reduction.
Tensor binary_cross_entropy_double_backward(
    const Tensor& grad_output,
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  const auto eps = 1e-12;
  const auto x_plus_eps = input + eps;
  const auto one_minus_x_plus_eps = 1 - input + eps;
  const auto numerator = input * input - 2 * input * target + target;
  auto gI = numerator / (x_plus_eps.pow(2) * one_minus_x_plus_eps.pow(2));

  const auto upstream = grad * grad_output;
  if (areAnyTensorSubclassLike({gI, grad})) {
    gI = gI * upstream;
  } else {
    gI *= upstream;
  }

  if (isDefined(weight)) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const Tensor& w = *weight;
    if (isTensorSubclassLike(w)) {
      gI = gI.mul(w);
    } else {
      gI *= w;
    }
  }

  if (reduction != at::Reduction::Mean) {
    return gI;
  }
  return gI / input.sym_numel();
}

// Gradient of binary_cross_entropy's backward with respect to grad_output:
// (x - y) / ((x + eps) * (1 - x + eps)) with eps = 1e-12. The result is
// scaled by `grad` (and the optional weight) and divided by numel for mean
// reduction.
Tensor binary_cross_entropy_double_backward_grad_output(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    const std::optional<Tensor>& weight,
    int64_t reduction) {
  const auto eps = 1e-12;
  const auto denom = (input + eps) * (1 - input + eps);
  auto ggO = (input - target) / denom;
  if (areAnyTensorSubclassLike({ggO, grad})) {
    ggO = ggO * grad;
  } else {
    ggO *= grad;
  }

  if (isDefined(weight)) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    const Tensor& w = *weight;
    if (isTensorSubclassLike(w)) {
      ggO = ggO.mul(w);
    } else {
      ggO *= w;
    }
  }

  if (reduction != at::Reduction::Mean) {
    return ggO;
  }
  return ggO / input.sym_numel();
}

// Double backward of smooth_l1_loss with respect to input. The second
// derivative is 1 / beta where |input - target| < beta and 0 elsewhere. It is
// scaled by `grad` and divided by numel for mean reduction.
Tensor smooth_l1_loss_double_backward(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    double beta) {
  // special case to protect against a divide-by-zero.
  // Build the zeros from the symbolic sizes: `grad.sizes()` cannot be used on
  // tensors with symbolic (dynamic) shapes.
  if (beta == 0) {
    return at::zeros_symint(grad.sym_sizes(), grad.options());
  }
  auto d = (input - target).abs();
  auto grad_input = grad * (d < beta).type_as(grad) / beta;
  if (reduction == at::Reduction::Mean) {
    grad_input /= input.sym_numel();
  }
  return grad_input;
}

// Double backward of huber_loss with respect to input. The second derivative
// is 1 where |input - target| < delta and 0 elsewhere. It is scaled by `grad`
// and divided by numel for mean reduction.
Tensor huber_loss_double_backward(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    double delta) {
  const auto quadratic_zone = (input - target).abs() < delta;
  auto grad_input = grad * quadratic_zone;
  if (reduction != at::Reduction::Mean) {
    return grad_input;
  }
  grad_input /= input.sym_numel();
  return grad_input;
}

// Gradient of huber_loss_backward with respect to grad_output, contracted
// with `grad`.
Tensor huber_loss_double_backward_grad_output(
    const Tensor& grad,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    double delta) {
  if (reduction != at::Reduction::None) {
    // Reduced loss: evaluate the backward at a unit grad_output and contract
    // it with `grad`.
    const auto unit_grad = huber_loss_backward(
        ones_like(grad_output), input, target, reduction, delta);
    return (unit_grad * grad).sum();
  }
  // Unreduced loss: apply the backward to `grad` directly.
  return huber_loss_backward(grad, input, target, reduction, delta);
}

// Double backward of mse_loss with respect to input: the second derivative of
// (input - target)^2 is the constant 2. It is scaled by `grad` and divided by
// numel for mean reduction.
Tensor mse_loss_double_backward(
    const Tensor& grad,
    const Tensor& input,
    int64_t reduction) {
  auto grad_input = 2 * grad;
  if (reduction != at::Reduction::Mean) {
    return grad_input;
  }
  grad_input /= input.sym_numel();
  return grad_input;
}

// Double backward of soft_margin_loss (log(1 + exp(-y x))) with respect to
// input: y^2 * e / (1 + e)^2 with e = exp(-y x). It is scaled by `grad` and
// divided by numel for mean reduction.
Tensor soft_margin_loss_double_backward(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction) {
  const auto e = (input * -target).exp();
  const auto one_plus_e = e + 1;
  auto grad_input = grad * (target * target) * e / (one_plus_e * one_plus_e);
  if (reduction != at::Reduction::Mean) {
    return grad_input;
  }
  grad_input /= input.sym_numel();
  return grad_input;
}

// Gradient of soft_margin_loss_backward with respect to grad_output,
// contracted with `grad`.
Tensor soft_margin_loss_double_backward_grad_output(
    const Tensor& grad,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction) {
  if (reduction != at::Reduction::None) {
    // Reduced loss: evaluate the backward at a unit grad_output and contract
    // it with `grad`.
    const auto unit_grad = soft_margin_loss_backward(
        ones_like(grad_output), input, target, reduction);
    return (unit_grad * grad).sum();
  }
  // Unreduced loss: apply the backward to `grad` directly.
  return soft_margin_loss_backward(grad, input, target, reduction);
}

// Double backward of softplus with respect to input, given x = input * beta:
// beta * sigmoid'(x) * grad, zeroed where x >= threshold.
Tensor softplus_double_backward(
    const Tensor& grad,
    const Tensor& input,
    const Scalar& beta,
    const Scalar& threshold) {
  const auto scaled = (input * beta);
  const auto below_threshold = (scaled < threshold).type_as(grad);
  return sigmoid_backward(grad, scaled.sigmoid()) * below_threshold * beta;
}

// NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
//
// `storage_offset` is ignored for simplicity in this note. If you just want the
// full algorithm without explanation, scroll down to bottom of this note.
//
// Implementing the backward of as_strided is tricky because you have to deal
// with mappings that map one memory location to multiple indices, i.e., the
// output tensor has multiple indices pointing to **overlapping** memory
// addresses. This can happen in all sorts of weird cases. For example,
//
//   x = torch.randn(15)
//   x.as_strided([3, 3], [1, 0])  # "expand" case
//   x.as_strided([3, 3], [2, 1])  # "size too large" case
//   x.as_strided([3, 2], [3, 6])  # res[2, 0] points to 2*3 + 0*6 = 6
//                                 # res[0, 1] points to 0*3 + 1*6 = 6
//
// Here is the general strategy we apply in implementing as_strided backward:
//   0. ??? (optimization step. we will talk about this later)
//   1. Create some underlying flattened tensor as if it is the base tensor
//      representing the contiguous memory storage for both input and output.
//   2. Use the output geometry to scatter (or index_add) the gradients into
//      this storage tensor.
//   3. ??? (fix for input tensor with overlapping memory. we will talk about
//           this later)
//   4. Return the as_strided view of the storage tensor using input geometry.
//
// In step (2), if the output tensor doesn't have overlapping memory, we can
// safely scatter (`storage.as_strided(output_geometry).copy_(grad)`);
// otherwise, we must use `index_add` as gradients at different indices may need
// to be summed to a single location.
//
// For example, in this case:
//
//   x = torch.randn(3)
//   y = x.as_strided([3, 3], [1, 0])  # "expand" case
//                                     # size   [ 3, 3]
//                                     # stride [ 1, 0]
//   y.backward()  # step (1): contiguous storage tensor `s` of size 3, which
//                             is large enough to be used as underlying storage
//                             for `x` and `y`.
//                               s = [ 0, 0, 0]
//                 # step (2): since `y` has overlapping memory, index_add grad
//                             into `s` basing on `y`'s geometry, i.e.,
//                             s[i * y.stride(0) + j * y.stride(1)] += gy[i, j].
//                               s = [ 3, 3, 3]
//                 # step (4): as_strided view `s` using `x`'s geometry
//                               s = [ 3, 3, 3]
//                               grad_input = s.as_strided(x.size(), x.stride())
//                                          = s.as_strided([3], [1])
//                                          = [ 3, 3, 3]
//
// This is exactly what we would get if using `expand`. However, here the input
// tensor doesn't have overlapping memory. If it does, we must add an extra step
// before (4). Consider this case:
//
//   t = torch.randn(3)
//   x = t.expand(3, 3)            # input with overlapping memory
//                                 # size   [3, 3]
//                                 # stride [0, 1]
//   y = x.as_strided([1], [1])    # contiguous output
//                                 # size   [1]
//                                 # stride [1]
//   y.backward()  # step (1): contiguous storage tensor `s` of size 3, which
//                             is large enough to be used as underlying storage
//                             for `x` and `y`.
//                               s = [ 0, 0, 0]
//                 # step (2): scatter grad into `s` basing on `y`'s geometry
//                               s = [ 1, 0, 0]
//                 # step (4): as_strided view `s` using `x`'s geometry
//                               s = [ 1, 0, 0]
//                               grad_input = s.as_strided(x.size(), x.stride())
//                                          = s.as_strided([3, 3], [0, 1])
//                                          = [[ 1, 0, 0],
//                                             [ 1, 0, 0],
//                                             [ 1, 0, 0]]
// Is this result correct?
//
// The call `x.as_strided([1], [1])` is obviously equivalent to
// `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second
// gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific
// case, indexing `x` at any index in the first column is also equivalent, and
// yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. The
// gradient we got from as_strided instead has a 1 at each of the `x.size(0)`
// entries of the first column (they all alias the same memory location), so
// there is an `x.size(0)`-times difference between these gradients computed
// from other PyTorch ops and the gradient we got from as_strided.
//
// You might conclude that the gradients from as_strided are wrong. However,
// let's first see why they are actually reasonable. Consider the pointwise
// perturbations by `delta` anywhere in the first column of `x`. It will lead to
// a `delta` change in the same memory location, and then `y` will change by
// `delta`. So one can say the gradient should be exactly 1 at the first column,
// as given by our above procedure.
//
// In the above computation of numerical gradients, they only match the
// analytical results because strides and memory locations are considered in the
// forward pass, i.e., this op (including both forward and backward) is
// layout-aware.
//
// However, in PyTorch, most (probably all) other ops (forward and backward) are
// layout-agnostic. E.g.,
//
//   t = torch.randn(1)
//   x = t.expand(2)
//   y = x.sum()
//   y.backward()
//
// Layout-agnostic autograd (as it is currently in PyTorch) will give you
//
//   gy = 1
//   gx = [ 1, 1]  # SumBackward:    torch.ones_like(x)
//   gt = [ 2]     # ExpandBackward: gx.sum()
//
// Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta`
// (the other will also change by `delta`), `y` will change by `2 * delta`. So
// the gradients, if strides are taken into consideration, should be 2.
//
// Layout-aware autograd should give you
//
//   gy = 1
//   gx = [ 2, 2]  # Because the backward considers the fact that the input `x`
//                 # is already expanded.
//   gt = [ 2]     # Layout-aware backward of expand is just a slicing because
//                 # the previous backward should have already taken care of
//                 # strides and made sure that gradients are the same along the
//                 # expanded dimension.
//
// As shown above, these two types are not compatible. Therefore, we must either
// make as_strided layout-agnostic, or make all other ops layout-aware.
//
// It is difficult to support layout-aware autograd (at least in the current
// codebase structure), because it would mean
//   1. storing tensor geometries of every input tensor for backward
//   2. the gradient computed by backward changing depending on input geometry
//   3. ideally enforcing gradient of T to always have same strides as T
// (although these two methods only differ when it comes to overlapping memory)
//
// Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e.,
// giving the same output regardless of the input layout. We consider
// `input.stride()` as a separate independent fixed argument `input_stride`.
// Then, `as_strided(input, size, stride)` can be thought of as:
//   1. "Scatter" each value of `input` into a "storage" using storage location
//      computed from the value's index in `input`, `input.size()` and
//      `input_stride`, but if N values end up in the same location, the value
//      is the average of those N values (they will be the same value anyway).
//
//      Formal description:
//        Denote the set of all input indices that point to the same storage
//        location `storage[n]` as `S(n)`, i.e.,
//
//            S(n) = { index : <index, input_stride> == n, index is valid given
//            input.size() },
//
//        where `<x, y>` is the dot product between `x` and `y`.
//
//        Then, the process is:
//
//            storage[n] = Avg { S(n) }
//
//        Note that all values in `S(n)` are the same (they point to the same
//        memory location anyway), so this step doesn't change anything, but it
//        effectively removes the dependency on the layout of `input`.
//        I.e., the result holds fixed regardless of the layout of `input`, as
//        long as `input_stride` is fixed.
//
//      NOTE: for forward pass, we can equivalently simply select any one of
//            `S(n)` as `storage[n]`. However, considering this as an average
//            operation makes backward easier (so all values in set
//            `{ grad_input[i] : i in S(n) }` are the same, and it can use the
//            same geometry as input).
//   2. As usual, return the as_strided view of `storage` using required output
//      `size` and `stride`.
//
// To backward through this layout-agnostic version, we simply add the following
// step:
//   .... (scatter gradients into the storage tensor using output geometry)
//   3. For all storage location n, `storage[n] /= |S(n)|`.
//   .... (return as_strided view of the storage tensor using input geometry)
//
// Finally, we note that these general operations are expensive, so we apply the
// following optimizations:
//   Add step (0): For all output dimension `d` with output stride 0, sum the
//                 gradients along dimension `d` (don't keepdim), and remove
//                 dimension `d` from output size and stride.
//                 (An optimization for "expand" cases so we may avoid step (3))
//  Only apply step (3) when input tensor has overlapping memory.
//
// FULL ALGORITHM:
//   0. For all output dimension `d` with output stride 0, sum the gradients
//       along dimension `d` (don't keepdim), and remove dimension `d` from
//       output size and stride.
//   1. Create some underlying flattened tensor as if it is the base tensor
//      representing the contiguous memory storage for both input and output.
//   2. Use the output geometry to scatter (or index_add) the gradients into
//      this storage tensor `storage`.
//   3. If input tensor has overlapping memory,
//      For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the
//      number of indices in input geometry pointing to the same storage
//      location `i` (i.e., `|S(i)|` in equations above).
//   4. Return the as_strided view of the storage tensor using input geometry.
//
// See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to
// roughly detect overlapping memory.

// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
//
// Checking memory overlap within a strided tensor is the special case of
// detecting memory overlap of two strided tensors, where the two tensors start
// at the same memory address. The latter is HARD (see #8212).
//
// But even this special case isn't simple. This note describes a check for an
// even more constrained case where we can be certain that there is no
// overlap.
//
// The checking algorithm can be described as:
//   0. Return [ pass check ] if any dimension has size 0
//   1. Ignore all dimensions that have size 1
//   2. If no remaining dimensions, return [ pass check ]
//   3. Sort the remaining dimensions according to the strides decreasingly
//   4. Check that for each dimension k,
//
//           stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i]
//
//      That is equivalent to, after reordering the dimensions so strides are
//      in decreasing order, checking that stride of each dimension is larger
//      than the maximum memory offset in a slice at that dimension.
//
// Obviously this check passes for contiguous tensors (the dimensions will be
// already sorted, with LHS = stride[0] = \prod_{i > 0} size[i] being exactly 1
// larger than RHS). Similarly, the check passes for tensors contiguous in all
// but the last dimension, with LHS = stride[0] = stride[-1] * \prod_{i > 0}
// size[i] being exactly stride[-1] larger than RHS. (*)
//
// We will show that these view operations, including all our view operations
// *except for* general as_strided and unfold, also preserve this invariant:
//
//  alias:      Obviously preserves
//
//  expand:     All changed dimensions are removed in step (1)
//
//  view:       Consider the input dimensions as grouped into consecutive
//              dimension "blocks", where dimensions are contiguous within each
//              block. view only works when the output dimensions can also be
//              grouped into the same consecutive blocks of same ordering.
//
//              NB: this means that the number of elements and stride of the
//                  last dimension in each block is the same in input and
//                  output. (**)
//
//              Notation:
//                Consider a single such block B,
//                    ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ...
//                                start--^^^^                  ^^^^^^^^^^^^--end
//                Each B[i] denotes a dimension index such that B[i] = B[0] + i.
//
//              We first show that if a tensor (i.e., input) satisfies the
//              invariant, then after sorting, the dimensions within each block
//              still remain consecutive. (***)
//
//                After removing dimensions of size 1, the dimensions within a
//                block are already sorted by strides in descending order. So
//                sorting all dimensions will not change the relative ordering
//                among them.
//
//                Assume that some block B is not consecutive after sorting,
//                i.e., there exists a dimension d between B[0] and B[-1] in
//                sorted order.
//
//                By (*), we know that
//                       stride[B[0]]
//                    =  \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]]
//                    <  \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[d]
//                    <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d]
//                    <= \sum_{j > B[0]} (size[j] - 1) * stride[j],
//
//                where the first  <  comes from sorting (d lies between B[0]
//                                               and B[-1], so stride[d] >
//                                               stride[B[-1]]),
//                      the second <= comes from the fact that dimension d
//                                               exists after step (1) and
//                                               thus must have size greater
//                                               than 1, and
//                      the third  <= comes from the fact that each term in
//                                               the sum is non-negative (the
//                                               last sum runs over all
//                                               dimensions after B[0] in
//                                               sorted order, including every
//                                               B[i] with i > 0 and d).
//
//                Then we have a contradiction, as the invariant would be
//                violated at B[0]. So the original proposition is true.
//
//              Now that we have established the above claim (***), we consider
//              the view operation as first sorting the dimensions (i.e.,
//              blocks), applying the original view (since it only cares about
//              dimensions being consecutive and contiguous within each block),
//              and then undoing the sort.
//
//              Consider a single block B in the output,
//                  ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
//                    start--^^^^                  ^^^^^^^^^^^^--end
//
//              By (*), we know that for all i
//                  stride[B[i]] = stride[B[-1]] +
//                                 \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]]
//
//              Then the invariant is obviously satisfied at every dimension
//              in this block if it is satisfied at dimension B[-1]. It only
//              remains to show that it is satisfied at the last dimension in
//              each block.
//
//              Since the same blocks are present in both input and output
//              with the same ordering, we will abuse the notation in the
//              following statements.
//
//              By (*), we know that the following holds for both input and
//              output, for any block B (the inner sum telescopes because each
//              block is contiguous):
//                    \sum_{i > B[-1]} (size[i] - 1) * stride[i]
//                  = \sum_{block B' after B} \sum_{j in B'} (size[j] - 1) * stride[j]
//                  = \sum_{block B' after B} (numel(B') - 1) * stride[B'[-1]].
//                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//              By (**), we know that this quantity in the above equation
//              remains the same in input and output. So both
//                  \sum_{i > B[-1]} (size[i] - 1) * stride[i]
//              and
//                  stride[B[-1]]
//              are the same in input and output.
//
//              These two quantities are exactly the LHS and RHS of the
//              invariant inequality. Since by assumption the invariant is
//              satisfied in input at B[-1], it is also satisfied in output at
//              B[-1]. This concludes the proof.
//
//  squeeze:    Special case of view
//
//  unsqueeze:  Special case of view
//
//  slice:      Consider slicing dimension i with step = k >= 1.
//
//              Let stride' and size' be the output strides and sizes. We have
//
//                  stride'[i] = k * stride[i]
//                  size'[i]  <= ceil(size[i] / k),
//
//              and hence (size'[i] - 1) * k <= size[i] - 1.
//
//              If size'[i] = 1, invariant is obviously satisfied as we are
//              just removing a dimension (after step (1)).
//
//              Assume size'[i] > 1.
//
//              By assumption, the invariant is satisfied at every dimension
//              in input.
//
//              For any dimension j, if stride[j] > stride[i], we have
//                  stride'[j] =  stride[j]
//                             >  (size[i] - 1) * stride[i]
//                             >= (size'[i] - 1) * k * stride[i]
//                             =  (size'[i] - 1) * stride'[i]
//                             >= stride'[i].
//
//              If stride[j] < stride[i], we have
//                  stride'[j] = stride[j] < stride[i] <= stride'[i].
//
//              So the sorting order remains unchanged after slice.
//
//              At dimension i itself, the LHS of the invariant can only grow
//              (stride'[i] >= stride[i]) while its RHS is unchanged. And since
//                     (size'[i] - 1) * stride'[i]
//                  =  (size'[i] - 1) * k * stride[i]
//                  <= (size[i] - 1) * stride[i],
//              the term from this dimension i in the invariant inequality at
//              other dimensions can only decrease after slice. So the
//              invariant is preserved.
//
//  narrow:     Special case of slice
//
//  select:     narrow + squeeze
//
//  permute:    Sorting makes permutation of dimensions irrelevant
//
//  transpose:  Sorting makes swapping dimensions irrelevant
//
//  diagonal:   Effectively merging two dimensions i and j into a new
//              dimension k s.t.
//                  stride'[k] =  stride[i] + stride[j]
//                  size'[k]   <= min(size[i], size[j]),
//              where stride and size are on the input, and stride' and size'
//              are on the output.
//
//              Assuming that size[i] > 1 and size[j] > 1. If any has size 1,
//              then this is unsqueeze on that dimension.
//
//              WLOG, say stride[i] <= stride[j].
//
//              Each dimension d in input with stride[d] > stride[j] has
//                  stride'[d] =  stride[d]
//                             >  (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j]
//                             >= stride[i] + stride[j]
//                             =  stride'[k].
//              So, considering the sorted dimensions, this is effectively
//              removing i, and replacing j with k.
//
//              For dimensions d with stride[i] < stride[d] < stride[j], the
//              term from dimension i is removed in the invariant inequality.
//              For dimensions d with stride[d] > stride[j], we have
//                     (size'[k] - 1) * stride'[k]
//                  <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
//                  <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
//              so the term from i and j in the invariant can only decrease.
//
//              So this is generally relaxing the constraint, and thus it
//              preserves it.

// Conservative self-overlap test for a strided geometry, implementing steps
// (2)~(4) of the algorithm in
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]. The caller
// (as_strided_backward) has already dropped size-0 and size-1 dimensions.
//
// Dimensions are visited from the smallest stride to the largest. Each one
// must have a stride strictly greater than the largest offset reachable with
// the dimensions visited before it; the first dimension that fails this
// makes us report a possible overlap. `false` therefore means "definitely no
// overlap" (assuming non-negative strides), while `true` only means overlap
// could not be ruled out. An empty geometry reports no overlap.
// Helper for as_strided_backward.
static inline bool _maybe_overlapping_memory(
    c10::SymIntArrayRef sizes,
    c10::SymIntArrayRef strides) {
  // Dimension indices ordered by increasing stride.
  std::vector<std::size_t> by_stride(sizes.size());
  std::iota(by_stride.begin(), by_stride.end(), 0);
  std::sort(
      by_stride.begin(), by_stride.end(), [&](std::size_t a, std::size_t b) {
        return strides[a] < strides[b];
      });

  // Largest element offset reachable using only the dimensions seen so far.
  c10::SymInt reach = 0;
  for (const auto dim : by_stride) {
    if (strides[dim] <= reach) {
      return true;
    }
    reach += strides[dim] * (sizes[dim] - 1);
  }
  return false;
}

// Number of storage elements needed to hold a tensor with the given `sizes`,
// `strides` and `storage_offset`. If any dimension has size 0 the tensor
// addresses nothing and only `storage_offset` is returned. Otherwise the
// result is one past the largest reachable offset:
//   storage_offset + 1 + sum_i (sizes[i] - 1) * strides[i].
// Helper for as_strided_backward.
static inline c10::SymInt _min_storage_size(
    c10::SymIntArrayRef sizes,
    c10::SymIntArrayRef strides,
    c10::SymInt storage_offset) {
  const bool has_empty_dim = std::any_of(
      sizes.begin(), sizes.end(), [](const c10::SymInt& extent) {
        return extent == 0;
      });
  if (has_empty_dim) {
    return storage_offset;
  }

  c10::SymInt end_offset = storage_offset + 1;
  for (const auto d : c10::irange(sizes.size())) {
    end_offset += (sizes[d] - 1) * strides[d];
  }
  return end_offset;
}

// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for
// explanation.
//
// Backward of `as_strided` w.r.t. its input. `grad` is the gradient of the
// as_strided output, whose geometry is given by `sym_sizes`, `sym_strides`
// and `sym_storage_offset_`; `input_geometry` describes the input. When
// `sym_storage_offset_` is empty, the input's storage offset is used as the
// output's storage offset.
//
// Returns a tensor with `input_geometry`'s sizes: a plain zero tensor if any
// output or input dimension has size 0, and otherwise an as_strided view
// (with the input's sizes and strides) into a freshly zero-initialized buffer.
// NOTE(review): assumes grad.dim() == sym_sizes.size() (grad has the output's
// shape); the loop below indexes sym_sizes with grad's dimensions.
Tensor as_strided_backward(
    Tensor grad,
    const TensorGeometry& input_geometry,
    c10::SymIntArrayRef sym_sizes,
    c10::SymIntArrayRef sym_strides,
    const std::optional<c10::SymInt>& sym_storage_offset_) {
  // Normalize the output geometry, walking dimensions from last to first so
  // that squeezing/summing dimension i never shifts the indices still to be
  // visited:
  //   - a size-0 dimension      -> the whole gradient is zero,
  //   - size-1 dimensions       -> squeezed out of `grad` and dropped,
  //   - expanded dimensions     -> (stride 0, size > 1) summed out of `grad`
  //                                and dropped,
  //   - all other dimensions    -> kept in out_sizes_ / out_strides_.
  // This is step (0) of the algorithm in NOTE [ as_strided Backward and
  // layout-aware/agnostic autograd ], and steps (0)~(1) of the algorithm in
  // NOTE [ Detecting Memory Overlap Within A Strided Tensor ] applied to the
  // output geometry.
  auto sym_storage_offset =
      sym_storage_offset_.value_or(input_geometry.sym_storage_offset());
  auto odim = grad.dim();
  std::vector<c10::SymInt> out_sizes_, out_strides_;
  out_sizes_.reserve(odim);
  out_strides_.reserve(odim);
  for (int64_t i = odim - 1; i >= 0; i--) {
    const auto& size_i = sym_sizes[i];
    const auto& stride_i = sym_strides[i];
    if (size_i == 0) {
      return at::zeros_symint(input_geometry.sym_sizes(), grad.options());
    } else if (size_i == 1) {
      grad = grad.squeeze(i);
    } else if (stride_i == 0) {
      grad = grad.sum(i, false);
    } else {
      out_sizes_.insert(out_sizes_.begin(), size_i);
      out_strides_.insert(out_strides_.begin(), stride_i);
    }
  }
  // Steps (2)~(4) of the algorithm in NOTE [ Detecting Memory Overlap Within
  // A Strided Tensor ], applied to the normalized output geometry.
  auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);

  // Normalize the input geometry the same way (only the geometry vectors are
  // built here; nothing is squeezed or summed):
  //   - a size-0 dimension  -> the whole gradient is zero,
  //   - size-1 dimensions   -> dropped.
  // Steps (0)~(1) of the algorithm in NOTE [ Detecting Memory Overlap Within
  // A Strided Tensor ], applied to the input geometry.
  auto idim = input_geometry.dim();
  auto inp_sizes = input_geometry.sym_sizes(),
       inp_strides = input_geometry.sym_strides();
  std::vector<c10::SymInt> inp_sizes_, inp_strides_;
  inp_sizes_.reserve(idim);
  inp_strides_.reserve(idim);
  for (int64_t i = idim - 1; i >= 0; i--) {
    const auto& size_i = inp_sizes[i];
    const auto& stride_i = inp_strides[i];
    if (size_i == 0) {
      return at::zeros_symint(input_geometry.sym_sizes(), grad.options());
    } else if (size_i != 1) {
      inp_sizes_.insert(inp_sizes_.begin(), size_i);
      inp_strides_.insert(inp_strides_.begin(), stride_i);
    }
  }
  // Steps (2)~(4) of the algorithm in NOTE [ Detecting Memory Overlap Within
  // A Strided Tensor ], applied to the normalized input geometry.
  auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);

  // Rest of this function implements
  // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and
  // layout-aware/agnostic autograd ]
  // TODO: Raise if not all output values are visible in input geometry.
  //       Technically speaking, if you treat those values as constants, not
  //       raising is fine, and mathematically correct. However, these values
  //       really are contained in some base tensor, and by treating them as
  //       constants we are ignoring this tight dependency. Therefore, it is
  //       more sensible to raise here.

  // Step (1): create a flat, zero-filled "storage" tensor. Both geometries are
  // rebased against the smaller of the two storage offsets, and the buffer is
  // made just large enough for whichever geometry reaches further.
  auto shared_offset =
      // TODO: symint-ify. Do we need a min() and max() for SymInts?
      // NOTE(review): SymInt::min()/max() are already used here and below;
      // this TODO may be stale.
      input_geometry.sym_storage_offset().min(sym_storage_offset);
  auto inp_effective_offset =
      input_geometry.sym_storage_offset() - shared_offset;
  auto out_effective_offset = sym_storage_offset - shared_offset;
  auto base_size1 =
      _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset);
  auto base_size2 =
      _min_storage_size(out_sizes_, out_strides_, out_effective_offset);
  auto base_size = base_size1.max(base_size2);
  auto storage = grad.new_zeros_symint(c10::SymIntArrayRef(base_size));

  // When either geometry may overlap we need index_add_, which needs the flat
  // storage position of every element. An arange over the buffer, viewed with
  // a geometry, yields exactly those positions. base_size is turned into a
  // concrete integer here via guard_int.
  std::optional<at::Tensor> flatten_full_indices;
  if (inp_maybe_overlap || out_maybe_overlap) {
    flatten_full_indices =
        // TODO: should we symint-ify arange? Need SymScalar.
        at::arange(
            0,
            base_size.guard_int(__FILE__, __LINE__),
            grad.options().dtype(at::kLong));
  }

  // Step (2): scatter the gradient into storage using the output geometry.
  // If the output may overlap, several grad elements can land on the same
  // storage element and must be summed (index_add_); otherwise a plain
  // strided copy is enough.
  if (out_maybe_overlap) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    auto out_indices = flatten_full_indices->as_strided_symint(
        out_sizes_, out_strides_, out_effective_offset);
    storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
  } else {
    // assume that new tensors have 0 storage offset (`storage` was freshly
    // allocated above, so offsets are relative to its first element)
    storage.as_strided_symint(out_sizes_, out_strides_, out_effective_offset)
        .copy_(grad);
  }

  // Step (3): if the input may overlap, divide storage[i] by the number of
  // input indices that map to storage position i (|S(i)| in the NOTE), so
  // every aliasing input element receives its averaged share.
  if (inp_maybe_overlap) {
    auto count = at::zeros_like(storage, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto inp_indices =
        // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
        flatten_full_indices
            ->as_strided_symint(inp_sizes_, inp_strides_, inp_effective_offset)
            .reshape(-1);
    count.index_add_(
        0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
    // Positions not reachable through the input geometry have count 0 and
    // become non-finite (NaN for 0/0). They are never read through the
    // input-geometry view returned in step (4).
    storage.div_(count); // this will give nan outside visible range
  }
  // Step (4): return the as_strided view of storage using the full input
  // geometry (size-1 dimensions included).
  return storage.as_strided_symint(
      inp_sizes, inp_strides, inp_effective_offset);
}

Tensor as_strided_scatter_backward(
    const Tensor& grad,
    const TensorGeometry& input_geometry,
    const TensorGeometry& src_geometry,
    c10::SymIntArrayRef sizes,
    c10::SymIntArrayRef strides,
    std::optional<c10::SymInt> storage_offset) {
  // Note [as_strided_scatter backward support]
  // as_strided_scatter handling for autograd is a beast, and is non-trivial to
  // implement for arbitrarily strided inputs. Most uses for as_strided with
  // functionalization only care about the contiguous case anyway, so for now
  // this is not implemented. When autograd is being used, we ban non-contiguous
  // inputs. We can assume that the input was a contiguous tensor. Also, we'll
  // take the perf hit and contiguify grad for now.
  //
  // The result has the input's sizes and strides. It is zero everywhere
  // except inside the window (`sizes`, `strides`, `storage_offset`), which
  // receives the same window of the contiguous `grad`. `src_geometry` is
  // currently unused.
  const auto contiguous_grad = grad.contiguous();
  auto zero_buffer = contiguous_grad.new_zeros_symint(input_geometry.sym_sizes());

  // Fill the window of the zero buffer from the matching window of grad.
  const auto source_window =
      contiguous_grad.as_strided_symint(sizes, strides, storage_offset);
  zero_buffer.as_strided_symint(sizes, strides, std::move(storage_offset))
      .copy_(source_window);

  return zero_buffer.as_strided_symint(
      input_geometry.sym_sizes(), input_geometry.sym_strides());
}

std::tuple<Tensor, Tensor> atan2_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& other,
    std::array<bool, 2> output_mask) {
  if (!grad.defined()) {
    return std::tuple<Tensor, Tensor>{Tensor(), Tensor()};
  }
  auto recip = (self * self + other * other).reciprocal();
  return std::tuple<Tensor, Tensor>{
      output_mask[0] ? grad * other * recip : Tensor(),
      output_mask[1] ? grad * -self * recip : Tensor()};
}

// Double backward of GELU: returns ggI * gO * gelu''(input).
// `approximate == "tanh"` selects the tanh approximation; any other value
// uses the exact (erf-based) GELU.
Tensor gelu_double_backward(
    const Tensor& ggI,
    const Tensor& gO,
    const Tensor& input,
    c10::string_view approximate) {
  if (approximate == "tanh") {
    // gelu(x) = 0.5 * x * (1 + tanh(u)), u = kBeta * (x + kKappa * x^3),
    // kBeta = sqrt(2 / pi).
    constexpr auto kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
    constexpr auto kKappa = 0.044715;

    const auto u = kBeta * (input + kKappa * pow(input, 3));
    const auto tanh_u = tanh(u);
    const auto sech_u = 1 / cosh(u);

    // gelu'(x) = 0.5 * (1 + tanh(u)) + f * g * h with:
    const auto f = 0.5 * input; // x / 2
    const auto g = 1 - tanh_u * tanh_u; // sech^2(u)
    const auto h = kBeta * (1 + 3 * kKappa * input * input); // du/dx

    // Differentiating gelu' gives f'gh from the first term plus the product
    // rule on f*g*h (f'gh + fg'h + fgh'); since f' = 0.5, the two f'gh terms
    // combine into 2 * df_gh below.
    const auto df_gh = 0.5 * g * h;
    const auto dg = (2 * sech_u) * (-sech_u * tanh_u) * h;
    const auto f_dg_h = f * h * dg;
    const auto dh = 6 * kKappa * input * kBeta;
    const auto f_g_dh = f * g * dh;

    return ggI * gO * (2 * df_gh + f_dg_h + f_g_dh);
  }

  // Exact GELU: gelu(x) = x * Phi(x), so gelu''(x) = 2 * phi(x) - x^2 * phi(x),
  // where phi is the standard normal density (kBeta = 1 / sqrt(2 * pi)).
  constexpr auto kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
  const auto x_sq = input * input;
  const auto phi = kBeta * at::exp(-0.5 * x_sq);
  return ggI * gO * (2 * phi - x_sq * phi);
}

// Double backward helper for ELU. `is_result` tells whether `self_or_result`
// holds the forward output (true) or the forward input (false).
//   - is_result:  grad * grad_output * input_scale, masked to
//                 self_or_result < 0.
//   - otherwise:  elu_backward applied to grad * grad_output * input_scale,
//                 masked to self_or_result < 0.
Tensor elu_double_backward(
    const Tensor& grad,
    const Tensor& grad_output,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale,
    bool is_result,
    const Tensor& self_or_result) {
  // Only the negative branch of ELU has a non-zero second derivative.
  const auto negative_mask = (self_or_result < 0).type_as(grad);
  const auto scaled = grad * grad_output * input_scale;

  if (!is_result) {
    return at::elu_backward(
               scaled, alpha, scale, input_scale, is_result, self_or_result) *
        negative_mask;
  }
  return scaled * negative_mask;
}

// Adapts optional slice bounds for slice_backward_symint: a missing `start`
// is passed as 0 and a missing `end` as INT64_MAX; all other arguments are
// forwarded unchanged.
Tensor slice_backward_wrapper(
    const at::Tensor& grad,
    const c10::SymIntArrayRef& input_sizes,
    int64_t dim,
    std::optional<c10::SymInt> start,
    std::optional<c10::SymInt> end,
    c10::SymInt step) {
  return slice_backward_symint(
      grad,
      input_sizes,
      dim,
      std::move(start).value_or(0),
      std::move(end).value_or(INT64_MAX),
      std::move(step));
}

// Forward-mode derivative (JVP) of the SVD A = U diag(S) V^H.
//
// `dA` is the tangent of the input; `U_`, `S`, `Vh_` are the forward results
// and `full_matrices` is the flag the forward was computed with. Returns the
// tangents (dU, dS, dVh). With `full_matrices`, only the first k = S.size(-1)
// columns of U_ / rows of Vh_ enter the computation, and the tangent of the
// remaining columns/rows is filled with zeros so the outputs match the
// shapes of U_ / Vh_.
std::tuple<Tensor, Tensor, Tensor> linalg_svd_jvp(
    const Tensor& dA,
    const Tensor& U_,
    const Tensor& S,
    const Tensor& Vh_,
    const bool full_matrices) {
  at::NoTF32Guard disable_tf32;
  // See svd_backward for the derivation
  // With sym(X) = X + X^H, we implement
  // dU = U (sym(dX S) / E + i Im(diag(dX)) / (2S))
  // if m > n
  //   dU = [dU for m == n] + (I_m - UU^H) dA V S^{-1}
  // dS = Re(diag(dP))
  // dV = V (sym(S dX) / E - i Im(diag(dX)) / (2S))
  // if m < n
  //   dV = [dV for m == n] + (I_n - VV^H) (dA)^H U S^{-1}
  // dVh = dV^H
  // with dP = U^H dA V
  //      dX = dP - dS
  //      E_{jk} = S_k^2 - S_j^2 if j != k
  //               1             otherwise

  // Checks compute_uv=true
  TORCH_INTERNAL_ASSERT(U_.dim() >= 2 && Vh_.dim() >= 2);

  const auto is_complex = dA.is_complex();
  const auto m = dA.size(-2);
  const auto n = dA.size(-1);
  const auto k = S.size(-1);

  // Restrict to the first k singular vectors (a no-op unless full_matrices).
  const auto U = full_matrices ? U_.narrow(-1, 0, k) : U_;
  const auto Vh = full_matrices ? Vh_.narrow(-2, 0, k) : Vh_;
  const auto V = Vh.mH();

  // dP = U^H dA V
  auto dP = m >= n ? at::matmul(U.mH(), at::matmul(dA, V))
                   : at::matmul(at::matmul(U.mH(), dA), V);

  // dS = Re(diag(dP)). In the real case this is a view of the diagonal of the
  // current dP; the reassignment of dP below is out-of-place, so dS is not
  // affected by it.
  auto dS =
      is_complex ? at::real(dP.diagonal(0, -2, -1)) : dP.diagonal(0, -2, -1);

  // dX = dP - dS
  // After this, the diagonal of dP is purely imaginary (zero in the real case).
  dP = dP - dS.diag_embed();

  // E_{jk} = S_k^2 - S_j^2 off the diagonal, 1 on the diagonal.
  auto E = [&S] {
    const auto S2 = S * S;
    auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1);
    // Any nonzero value works on the diagonal: the numerators divided by E
    // below (sym(...)) have a zero diagonal, so this only ever computes 0 / a.
    ret.diagonal(0, -2, -1).fill_(1);
    return ret;
  }();

  const auto sym = [](const Tensor& X) { return X + X.mH(); };

  // diag(dP) / (2S) = i Im(diag(dP_original)) / (2S); only needed in the
  // complex case.
  auto diagdP2S = is_complex ? dP.diagonal(0, -2, -1).div(2. * S) : Tensor{};

  // dU = U (sym(dP S) / E + i Im(diag(dP)) / (2S))
  auto dU = [&] {
    auto dUaux = sym(dP * S.unsqueeze(-2)) / E;
    if (is_complex) {
      dUaux = dUaux + diagdP2S.diag_embed();
    }
    return at::matmul(U, dUaux);
  }();
  if (m > n) {
    // dU += (I_m - UU^H) dA V S^{-1}
    const auto dAVSinv = at::matmul(dA, V / S.unsqueeze(-2));
    dU = dU + dAVSinv - at::matmul(U, at::matmul(U.mH(), dAVSinv));

    // To "fix" the full_matrices case (the full_matrices case should not be
    // differentiable...): pad dU with m - n zero columns to match U_.
    if (full_matrices) {
      auto shape = dU.sizes().vec();
      shape.end()[-1] = m - n;
      dU = at::cat({dU, dU.new_zeros(shape)}, /*dim=*/-1);
    }
  }

  // dVh = (-sym(S dP) / E + i Im(diag(dP)) / (2S)) Vh
  // Perf: We negate the S as it's the smallest tensor in the equation
  auto dVh = [&] {
    auto dVhaux = sym(dP * (-S).unsqueeze(-1)) / E;
    if (is_complex) {
      dVhaux = dVhaux + diagdP2S.diag_embed();
    }
    return at::matmul(dVhaux, Vh);
  }();
  if (m < n) {
    // dVh += S^{-1} U^H dA (I_n - VV^H)
    const auto UHdASinv = at::matmul(U.mH() / S.unsqueeze(-1), dA);
    dVh = dVh + UHdASinv - at::matmul(at::matmul(UHdASinv, V), Vh);

    // To "fix" the full_matrices case (the full_matrices case should not be
    // differentiable...): pad dVh with n - m zero rows to match Vh_.
    if (full_matrices) {
      auto shape = dVh.sizes().vec();
      shape.end()[-2] = n - m;
      dVh = at::cat({dVh, dVh.new_zeros(shape)}, /*dim=*/-2);
    }
  }

  return std::make_tuple(std::move(dU), std::move(dS), std::move(dVh));
}

Tensor svd_backward(
    const Tensor& gU,
    const Tensor& gS,
    const Tensor& gVh,
    const Tensor& U,
    const Tensor& S,
    const Tensor& Vh) {
  at::NoTF32Guard disable_tf32;
  // Backward of A = U diag(S) Vh. Returns gA, or an undefined tensor when no
  // incoming gradient is defined.
  //
  // Throughout both the real and complex case we assume A has distinct singular
  // values. Furthermore, if A is rectangular or complex, we assume it's
  // full-rank.
  //
  //
  // The real case (A \in R)
  // See e.g. https://j-towns.github.io/papers/svd-derivative.pdf
  //
  // Denote by skew(X) = X - X^T, and by A o B the coordinatewise product, then
  // if m == n
  //   gA = U [(skew(U^T gU) / E)S + S(skew(V^T gV) / E) + I o gS ]V^T
  // where E_{jk} = S_k^2 - S_j^2 if j != k and 1 otherwise
  //
  // if m > n
  //   gA = [term in m == n] + (I_m - UU^T)gU S^{-1} V^T
  // if m < n
  //   gA = [term in m == n] + U S^{-1} (gV)^T (I_n - VV^T)
  //
  //
  // The complex case (A \in C)
  // This one is trickier because the svd is not locally unique.
  // Denote L = diag(e^{i\theta_k}), then we have that if A = USV^H, then (UL,
  // S, VL) is another valid SVD decomposition of A as A = ULS(VL)^H =
  // ULSL^{-1}V^H = USV^H, since L, S and L^{-1} commute, since they are all
  // diagonal.
  //
  // Assume wlog that n >= k in what follows, as otherwise we could reason about
  // A^H. Denote by St_k(C^n) = {A \in C^{n,k} | A^H A = I_k} the complex
  // Stiefel manifold. What this invariance means is that the svd decomposition
  // is not a map svd: C^{n x k} -> St_k(C^n) x R^k x St_k(C^k) (where St_k(C^k)
  // is simply the unitary group U(k)) but a map svd: C^{n x k} -> M x R^k where
  // M is the manifold given by quotienting St_k(C^n) x U(k) by the action
  // (U, V) -> (UL, VL) with L as above. Note that M is a manifold, because the
  // action is free and proper (as U(1)^k \iso (S^1)^k is compact). For this
  // reason, pi : St_k(C^n) x U(k) -> M forms a principal bundle.
  //
  // To think about M, consider the case k = 1. Then, we have the bundle
  // pi : St_1(C^n) x U(1) -> M
  // now, St_1(C^n) are just vectors of norm 1 in C^n. That's exactly the sphere
  // S^{2n-1} = { z \in C^n | z^H z = 1 } of dimension 2n-1 in C^n \iso R^{2n}.
  // Then, in this case, we're quotienting out U(1) completely, so we get that
  // pi : S^{2n-1} x U(1) -> CP(n-1)
  // where CP(n-1) is the complex projective space of dimension n-1.
  // In other words, M is just the complex projective space, and pi is (pretty
  // similar to) the usual principal bundle from S^{2n-1} to CP(n-1). The case
  // k > 1 is the same, but requiring a linear independence condition between
  // the vectors from the different S^{2n-1} or CP(n-1).
  //
  // Note that this is a U(1)^k-bundle. In plain words, this means that the
  // fibres of this bundle, i.e. pi^{-1}(x) for x \in M are isomorphic to
  // U(1) x ... x U(1). This is obvious as, if pi(U,V) = x,
  // pi^{-1}(x) = {(U diag(e^{i\theta}), V diag(e^{i\theta})) | \theta \in R^k}
  //            = {(U diag(z), V diag(z)) | z \in U(1)^k}
  // since U(1) = {z \in C | |z| = 1}.
  //
  // The big issue here is that M with its induced metric is not locally
  // isometric to St_k(C^n) x U(k). [The why is rather technical, but you can
  // see that the horizontal distribution is not involutive, and hence not
  // integrable due to Frobenius' theorem] What this means in plain words is
  // that, no matter how we choose to return the U and V from the SVD, we won't
  // be able to simply differentiate wrt. U and V and call it a day. An example
  // of a case where we can do this is when performing an eigendecomposition on
  // a real matrix that happens to have real eigendecomposition. In this case,
  // even though you can rescale the eigenvectors by any real number, you can
  // choose them of norm 1 and call it a day. In the eigenvector case, we are
  // using that you can isometrically embed S^{n-1} into R^n. In the svd case,
  // we need to work with the "quotient manifold" M explicitly, which is
  // slightly more technically challenging.
  //
  // Since the columns of U and V are not uniquely defined, but are
  // representatives of certain classes of equivalence which represent elements
  // of M, the user may not depend on the particular representative that we
  // return from the SVD. In particular, if the loss function depends on U or V,
  // it must be invariant under the transformation (U, V) -> (UL, VL) with
  // L = diag(e^{i\theta}), for every \theta \in R^k. In more geometrical terms,
  // this means that the loss function should be constant on the fibres, or, in
  // other words, the gradient along the fibres should be zero. We may see this
  // by checking that the gradients as element in the tangent space
  // T_{(U, V)}(St(n,k) x U(k)) are normal to the fibres. Differentiating the
  // map (U, V) -> (UL, VL), we see that the space tangent to the fibres is
  // given by
  // Vert_{(U, V)}(St(n,k) x U(k)) = { i[U, V]diag(\theta) | \theta in R^k}
  // where [U, V] denotes the vertical concatenation of U and V to form an
  // (n+k, k) matrix. Then, solving <i[U,V]diag(\theta), [S, T]> = 0 for two
  // matrices S, T \in T_{(U, V)}(St(n,k) x U(k)) where <A, B> = Re tr(A^H B) is
  // the canonical (real) inner product in C^{n x k} we get that the function
  // is invariant under action of U(1)^k iff Im(diag(U^H gU + V^H gV)) = 0
  //
  // Using this in the derivation for the forward AD and writing
  // sym(X) = X + X^H, we get that the forward AD for SVD in the complex case
  // is given by
  // dU = U (sym(dX S) / E + i Im(diag(dX)) / (2S))
  // if m > n
  //   dU = [dU for m == n] + (I_m - UU^H) dA V S^{-1}
  // dS = Re(diag(dP))
  // dV = V (sym(S dX) / E - i Im(diag(dX)) / (2S))
  // if m < n
  //   dV = [dV for m == n] + (I_n - VV^H) (dA)^H U S^{-1}
  // dVh = dV^H
  // with dP = U^H dA V
  //      dX = dP - dS
  //      E_{jk} = S_k^2 - S_j^2 if j != k
  //               1             otherwise
  //
  // Similarly, writing skew(X) = X - X^H
  // the adjoint wrt. the canonical metric is given by
  // if m == n
  //   gA = U [((skew(U^H gU) / E) S + i Im(diag(U^H gU)) / S
  //            + S ((skew(V^H gV) / E)) + I o gS] V^H
  // if m > n
  //   gA = [term in m == n] + (I_m - UU^H)gU S^{-1} V^H
  // if m < n
  //   gA = [term in m == n] + U S^{-1} (gV)^H (I_n - VV^H)
  // where we have used that Im(diag(U^H gU)) = - Im(diag(V^H gV)) to group the
  // diagonal imaginary terms into one that just depends on U^H gU.

  // Checks compute_uv=true: U and Vh must be (batches of) matrices
  TORCH_INTERNAL_ASSERT(U.dim() >= 2 && Vh.dim() >= 2);

  // Trivial case: no incoming gradient at all
  if (!gS.defined() && !gU.defined() && !gVh.defined()) {
    return {};
  }

  const auto m = U.sym_size(-2);
  const auto n = Vh.sym_size(-1);

  // Optimisation for svdvals: gA = U @ diag(gS) @ Vh
  // diag(gS) is folded in as a row scaling of Vh (m >= n) or a column scaling
  // of U (m < n) instead of materialising the diagonal matrix.
  if (!gU.defined() && !gVh.defined()) {
    return m >= n ? at::matmul(U, gS.unsqueeze(-1) * Vh)
                  : at::matmul(U * gS.unsqueeze(-2), Vh);
  }
  // At this point, at least one of gU, gVh is defined

  const bool is_complex = U.is_complex();
  const auto skew = [](const Tensor& A) { return A - A.mH(); };
  // UhgU = skew(U^H gU). Since gV = gVh^H, Vh gVh^H = V^H gV, so
  // VhgV = skew(V^H gV). Each is left undefined when its gradient is missing.
  const auto UhgU = gU.defined() ? skew(at::matmul(U.mH(), gU)) : Tensor{};
  const auto VhgV = gVh.defined() ? skew(at::matmul(Vh, gVh.mH())) : Tensor{};

  // Check for the invariance of the loss function, i.e.
  // Im(diag(U^H gU)) + Im(diag(V^H gV)) = 0
  // (the diagonal of skew(X) is 2i Im(diag(X)), so comparing the imaginary
  // parts of the skew matrices is equivalent up to a factor of 2)
  if (is_complex) {
    const auto imdiag_UhgU =
        gU.defined() ? at::imag(UhgU.diagonal(0, -2, -1)) : at::zeros_like(S);
    const auto imdiag_VhgV =
        gVh.defined() ? at::imag(VhgV.diagonal(0, -2, -1)) : at::zeros_like(S);
    // Rather lax atol and rtol, as we don't want false positives
    TORCH_CHECK(
        at::allclose(imdiag_UhgU, -imdiag_VhgV, /*rtol=*/1e-2, /*atol=*/1e-2),
        "svd_backward: The singular vectors in the complex case are specified up to multiplication "
        "by e^{i phi}. The specified loss function depends on this phase term, making "
        "it ill-defined.");
  }

  // gA = (skew(U^H gU) / E) S + S (skew(V^H gV) / E)
  //      + I o (gS + diag(skew(U^H gU)) / (2S))
  // (this is the m == n term, still without the U ... V^H conjugation)
  Tensor gA = [&] {
    // ret holds everything but the diagonal of gA
    auto ret = [&] {
      // E_{jk} = S_k^2 - S_j^2 off the diagonal
      const auto E = [&S] {
        const auto S2 = S * S;
        auto ret = S2.unsqueeze(-2) - S2.unsqueeze(-1);
        // Any number a != 0 would, as we are just going to use it to compute 0
        // / a later on
        ret.diagonal(0, -2, -1).fill_(1);
        return ret;
      }();

      if (gU.defined()) {
        if (gVh.defined()) {
          return (UhgU * S.unsqueeze(-2) + S.unsqueeze(-1) * VhgV) / E;
        } else {
          return (UhgU / E) * S.unsqueeze(-2);
        }
      } else { // gVh.defined();
        return S.unsqueeze(-1) * (VhgV / E);
      }
    }();
    // Fill the diagonal
    if (gS.defined()) {
      ret = ret + gS.diag_embed();
    }
    // diag(UhgU) / (2S) = i Im(diag(U^H gU)) / S. When only one of gU / gVh is
    // defined, the invariance check above already forces this imaginary
    // diagonal to be (approximately) zero.
    if (is_complex && gU.defined() && gVh.defined()) {
      ret = ret + (UhgU.diagonal(0, -2, -1) / (2. * S)).diag_embed();
    }
    return ret;
  }();

  if (m > n && gU.defined()) {
    // gA = [UgA + (I_m - UU^H)gU S^{-1}]V^H
    gA = at::matmul(U, gA);
    const auto gUSinv = gU / S.unsqueeze(-2);
    gA = gA + gUSinv - at::matmul(U, at::matmul(U.mH(), gUSinv));
    gA = at::matmul(gA, Vh);
  } else if (m < n && gVh.defined()) {
    //   gA = U[gA V^H + S^{-1} (gV)^H (I_n - VV^H)]
    gA = at::matmul(gA, Vh);
    const auto SinvgVh = gVh / S.unsqueeze(-1);
    gA = gA + SinvgVh - at::matmul(at::matmul(SinvgVh, Vh.mH()), Vh);
    gA = at::matmul(U, gA);
  } else {
    // gA = U gA V^H
    // The association order is chosen based on the shape (m >= n vs m < n).
    gA = m >= n ? at::matmul(U, at::matmul(gA, Vh))
                : at::matmul(at::matmul(U, gA), Vh);
  }

  return gA;
}

Tensor linalg_eig_backward(
    const Tensor& gL,
    const Tensor& gV,
    const Tensor& L,
    const Tensor& V,
    const bool is_hermitian,
    const bool symeig_eigenvectors) {
  at::NoTF32Guard disable_tf32;
  // Backward of the eigendecomposition A = V diag(L) V^{-1} (eig), or of the
  // Hermitian one when is_hermitian is true. Returns gA, or an undefined
  // tensor when neither gL nor gV is defined.
  //
  // https://arxiv.org/pdf/1701.00392.pdf Eq 4.77
  // For A = VLV^{-1}, denoting the gradients gA, gV and gL, we have
  // gA = V^{-H}(diag_embed(gL) + (V^H gV -V^HV diag(real(V^H gV))) / E*)V^H
  // Where:
  //   - E_{ij} = L_j - L_i if i != j
  //              1         otherwise
  //   - diag_embed takes a vector into a diagonal matrix
  //   - diag zeroes out elements outside of the diagonal

  // Note: the term '-V^HV diag(real(V^H gV))' comes from the fact that the
  // eigenvalue decomposition is returned with eigenvectors normalized to have
  // norm one.

  // Note: The Hermitian case is a simplification of this formula using that
  // V^{-1} = V^H and that L is real

  // This check can only be triggered in the backward of torch.symeig
  TORCH_CHECK(
      symeig_eigenvectors,
      "linalg_eig_backward: torch.symeig(A, eigenvectors=False) is not differentiable. ",
      "Use torch.linalg.eigvalsh(A) instead.");

  // Trivial case
  if (!gL.defined() && !gV.defined()) {
    return {};
  }

  // Shortcut for linalg.eigvals/eigvalsh
  // Compute V^-H gL V^H
  if (!gV.defined()) {
    if (is_hermitian) {
      return at::matmul(V * gL.unsqueeze(-2), V.mH());
    } else {
      return at::linalg_solve(V.mH(), gL.unsqueeze(-1) * V.mH());
    }
  }
  auto VhgV = at::matmul(V.mH(), gV);
  // NB: diag_VhgV is a view of this (unprojected) V^H gV. The reassignments of
  // VhgV below rebind the handle to new tensors, so diag_VhgV keeps referring
  // to the diagonal of the original V^H gV.
  const auto diag_VhgV = VhgV.diagonal(0, -2, -1);

  // The check is skipped for tensor subclasses (presumably because allclose
  // needs concrete values).
  if (V.is_complex() && !at::isTensorSubclassLike(diag_VhgV)) {
    // Check invariance of the loss function wrt the transformation
    // V -> V * e^{i\phi} for an arbitrary phi in RR^n
    const auto imdiag_VhgV = at::imag(diag_VhgV);
    TORCH_CHECK(
        at::allclose(
            imdiag_VhgV,
            at::zeros_like(imdiag_VhgV),
            /*rtol=*/1e-2,
            /*atol=*/1e-2),
        is_hermitian ? "linalg_eigh_backward" : "linalg_eig_backward",
        ": The eigenvectors in the complex case are specified up to multiplication ",
        "by e^{i phi}. The specified loss function depends on this quantity, so it is ill-defined.");
  }

  if (is_hermitian) {
    // Project onto the tangent space at the identity of U(n), that is, the
    // skew-Hermitian matrices
    VhgV = 0.5 * (VhgV - VhgV.mH());
  } else {
    // Project onto the tangent space at V^H V of complex matrices with columns
    // of norm 1
    VhgV = VhgV - at::matmul(V.mH(), V * at::real(diag_VhgV).unsqueeze(-2));
  }

  auto gA = [&, VhgV = std::move(VhgV)] {
    // Econj_{ij} = conj(L_j) - conj(L_i) off the diagonal, 1 on it
    auto Econj = [&L] {
      auto Lconj = L.conj();
      auto ret = Lconj.unsqueeze(-2) - Lconj.unsqueeze(-1);
      ret.diagonal(0, -2, -1).fill_(1.);
      return ret;
    }();

    // In-place division is fine: the projected VhgV is a fresh tensor owned by
    // this lambda.
    auto ret = VhgV.div_(Econj);

    // The diagonal of ret is overwritten with gL. After the projection above
    // that diagonal is i Im(diag(V^H gV)), which the invariance check requires
    // to be ~0 (up to its tolerance) for complex V.
    if (gL.defined()) {
      // For CompositeCompliance, if `gL` is subclass but `ret`
      // is a regular Tensor, then use out-of-place version of diagonal
      // copy aka `diagonal_scatter`.
      if (at::isTensorSubclassLike(gL)) {
        ret = ret.diagonal_scatter(gL, 0, -2, -1);
      } else {
        ret.diagonal(0, -2, -1).copy_(gL);
      }
    }
    return ret;
  }();

  // Conjugate by V^{-H}
  if (is_hermitian) {
    return at::matmul(V, at::matmul(gA, V.mH()));
  } else {
    return at::linalg_solve(V.mH(), at::matmul(gA, V.mH()));
  }
}

std::tuple<Tensor, Tensor> linalg_eig_jvp(
    const Tensor& dA,
    const Tensor& L,
    const Tensor& V,
    const bool is_hermitian) {
  at::NoTF32Guard disable_tf32;
  // https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
  // see also https://arxiv.org/pdf/1701.00392.pdf Eqs. (4.60) and (4.63)
  // Note that neither of the formulas in these pdfs are correct, as they do not
  // assume that the eigenvectors are of unit norm. As such, they are missing
  // the diagonal term in dV dL = diag(dP) dV = dX - V Re(diag V^H dX)) where dP
  // = V^{-1} dA V dX = V ((dP - diag(dP)) / E) E_{ij} = L_j - L_i if i != j
  //          1         otherwise

  // Precondition: if is_hermitian == true, then dA is Hermitian
  const auto to_complex = [](const Tensor& A) {
    return A.to(c10::toComplexType(A.scalar_type()));
  };

  const auto dP = is_hermitian
      ? at::matmul(at::matmul(V.mH(), dA), V)
      : at::linalg_solve(V, at::matmul(to_complex(dA), V));
  auto dL = is_hermitian && dA.is_complex() ? at::real(dP.diagonal(0, -2, -1))
                                            : dP.diagonal(0, -2, -1);
  auto dV = [&dP, &V, &L, is_hermitian] {
    auto dX = [&] {
      auto ret = dP / (L.unsqueeze(-2) - L.unsqueeze(-1));
      ret.diagonal(0, -2, -1).zero_();
      ret = at::matmul(V, ret);
      return ret;
    }();

    if (is_hermitian) {
      return dX;
    } else {
      return dX -
          V *
          at::real(at::matmul(V.mH(), dX).diagonal(0, -2, -1)).unsqueeze(-2);
    }
  }();
  return std::make_pair(std::move(dL), std::move(dV));
}

Tensor linalg_lstsq_jvp(
    const Tensor& A,
    const Tensor& B,
    const Tensor& dA,
    const Tensor& dB) {
  at::NoTF32Guard disable_tf32;
  // The least-squares solution is X = pinv(A) B, so by the product rule
  //   dX = d(pinv(A)) B + pinv(A) dB
  auto A_pinv = at::linalg_pinv(A);
  auto dA_pinv = pinv_jvp(A, A_pinv, dA);
  return dA_pinv.matmul(B) + A_pinv.matmul(dB);
}

std::tuple<Tensor, Tensor> linalg_lstsq_backward(
    const Tensor& gX_,
    const Tensor& A,
    const Tensor& B_,
    const std::array<bool, 2>& grad_input_mask) {
  at::NoTF32Guard disable_tf32;
  auto A_requires_grad = grad_input_mask[0];
  auto B_requires_grad = grad_input_mask[1];
  if (!gX_.defined() || (!A_requires_grad && !B_requires_grad)) {
    return {};
  }

  const bool vector_case = at::native::linalg_solve_is_vector_rhs(A, B_);
  const auto vector_to_matrix = [vector_case](const Tensor& X) {
    return vector_case ? X.unsqueeze(-1) : X;
  };
  const auto matrix_to_vector = [vector_case](const Tensor& X) {
    return vector_case ? X.squeeze(-1) : X;
  };

  auto gX = vector_to_matrix(gX_);
  auto B = vector_to_matrix(B_);
  Tensor pinvA = at::linalg_pinv(A);
  Tensor A_grad, B_grad;
  if (A_requires_grad) {
    auto pinvA_grad = gX.matmul(B.mH());
    A_grad = pinv_backward(pinvA_grad, pinvA, A);
  }

  if (B_requires_grad) {
    // Equivalent to
    // B_grad = std::get<0>(at::linalg_lstsq(A.mH(), gX, rcond, driver));
    // but we avoid this approach as `gelsy` is non-deterministic
    B_grad = matrix_to_vector(pinvA.mH().matmul(gX));
  }

  return std::make_tuple(A_grad, B_grad);
}

std::tuple<Tensor, Tensor> linalg_qr_jvp(
    const Tensor& dA,
    const Tensor& Q,
    const Tensor& R,
    const c10::string_view mode) {
  // Forward AD of A = QR. Returns (dQ, dR).
  // Differentiating gives dA = dQ R + Q dR.
  //
  // Case m >= n
  //   Write dQ in terms of dR: dQ = dA R^{-1} - Q dR R^{-1}, so
  //     Q^H dA R^{-1} = Q^H dQ + dR R^{-1}
  //   where Q^H dQ is skew-Hermitian and dR R^{-1} is upper triangular.
  //   With sym(X) = X + X^H this yields sym(dR R^{-1}) = sym(Q^H dA R^{-1}).
  //   syminv(X) = triu(X) - 0.5 diag(X) inverts
  //   sym : Triu(k, real diagonal) -> Her(k), hence
  //     dR = syminv(sym(Q^H dA R^{-1})) R
  //
  // Case m < n
  //   Write dR in terms of dQ: dR = Q^H dA - Q^H dQ R.
  //   Let X_1 be the leading m x m block of X \in C^{m x n}. Then
  //     Q^H dA_1 R_1^{-1} = Q^H dQ + dR_1 R_1^{-1}
  //   With trilIm(X) = tril(X, -1) + i Im(diag(X)),
  //     trilIm(Q^H dQ) = trilIm(Q^H dA_1 R_1^{-1})
  //   and trilImInv(X) = X - X^H - i Im(diag(X)) inverts trilIm on
  //   skew-Hermitian matrices (not on arbitrary ones), so
  //     dQ = Q trilImInv(trilIm(Q^H dA_1 R_1^{-1}))
  at::NoTF32Guard disable_tf32;

  auto [compute_q, reduced] = at::native::_parse_qr_mode(mode);

  TORCH_CHECK(
      compute_q,
      "The derivative of linalg.qr depends on Q, which is not computed when "
      "mode='r'. Please use linalg.qr(A, mode='reduced') if you are "
      "going to differentiate through linalg.qr.");
  const auto m = dA.size(-2);
  const auto n = dA.size(-1);

  TORCH_CHECK(
      reduced || m <= n,
      "The QR decomposition is not differentiable when "
      "mode='complete' and nrows > ncols.");

  if (m < n) {
    // Wide case: solve for Q^H dQ first, then read dR off dA = dQ R + Q dR.
    auto QHdA = Q.mH().matmul(dA);
    auto QHdA1Rinv = at::linalg_solve_triangular(
        R.narrow(-1, 0, m),
        QHdA.narrow(-1, 0, m),
        /*upper=*/true,
        /*left=*/false);

    // skew = trilImInv(trilIm(Q^H dA_1 R_1^{-1})) = Q^H dQ
    Tensor skew;
    if (QHdA1Rinv.is_complex()) {
      auto lower = QHdA1Rinv.tril();
      at::real(lower.diagonal(0, -2, -1)).zero_();
      skew = lower - lower.mH();
      skew.diagonal(0, -2, -1).mul_(0.5);
    } else {
      auto lower = QHdA1Rinv.tril(-1);
      skew = lower - lower.mT();
    }

    auto dR = QHdA - skew.matmul(R);
    auto dQ = Q.matmul(skew);
    return std::make_tuple(std::move(dQ), std::move(dR));
  }

  // Tall (or square) case: solve for dR R^{-1} first.
  auto dARinv =
      at::linalg_solve_triangular(R, dA, /*upper=*/true, /*left=*/false);
  // dRRinv = syminv(sym(Q^H dA R^{-1})) = dR R^{-1}
  auto QHdARinv = Q.mH().matmul(dARinv);
  auto dRRinv = (QHdARinv + QHdARinv.mH()).triu();
  dRRinv.diagonal(0, -2, -1).mul_(0.5);
  auto dQ = dARinv - Q.matmul(dRRinv);
  auto dR = dRRinv.matmul(R);
  return std::make_tuple(std::move(dQ), std::move(dR));
}

Tensor linalg_qr_backward(
    const Tensor& gQ,
    const Tensor& gR,
    const Tensor& Q,
    const Tensor& R,
    const c10::string_view mode) {
  // Backward of A = QR. Returns gA, or an undefined tensor when neither gQ nor
  // gR is defined.
  //
  // Nb. We won't be too formal below, as writing this proof formally is a pain
  // We'll link here a formal writing of all this at some point in the future
  //
  // Case m >= n
  // dQ = dAR^{-1} - Qsyminv(sym(Q^H dA R^{-1}))
  // dR = syminv(sym(Q^H dA R^{-1}))R
  //
  // With the notation from the JVP formula, the only two computations that we
  // need are
  //   syminv*(R) = 0.5 * (R.triu() + R.triu()^H - Re diag(R))
  //   sym*(X) = 2 * X
  // Using these, after a few simplifications we get that
  //   gA = (gQ + Q syminvadj(triu(gR R^H - Q^H gQ))) R^{-H}
  //
  // Case m < n
  // dR = Q^H dA - Q^H dQ R
  // dQ = Q trilImInv(trilIm(Q^H dA_1 R_1^{-1}))
  //
  // In this case trilIm*(X) = X (it's the trivial embedding)
  // while trilImInv*(X) = tril(Y) - 0.5 * diag(Y)
  // with Y = X - X^H
  //
  // We also have that if X \in C^{m, n} and pi(X) = X_1 projects X onto its
  // leading m x m submatrix, then
  // pi*(X) = cat(X, 0_{m,n-m}, dim=-1)
  //
  // Using this, we get that
  // gA = QgR + pi*(Q trilImInv*(Q^H gQ - gR R^H)R_1^{-H})
  at::NoTF32Guard disable_tf32;

  auto [compute_q, reduced] = at::native::_parse_qr_mode(mode);

  TORCH_CHECK(
      compute_q,
      "The derivative of linalg.qr depends on Q, which is not computed when "
      "mode='r'. Please use linalg.qr(A, mode='reduced') if you are "
      "going to differentiate through linalg.qr.");

  auto m = Q.sym_size(-2);
  auto n = R.sym_size(-1);

  TORCH_CHECK(
      reduced || m <= n,
      "The QR decomposition is not differentiable when "
      "mode='complete' and nrows > ncols.");

  if (!gQ.defined() && !gR.defined()) {
    return {};
  }

  // gA = gR R^H - Q^H gQ, dropping whichever term has an undefined gradient
  Tensor gA;
  if (gQ.defined()) {
    if (gR.defined()) {
      gA = gR.matmul(R.mH()) - Q.mH().matmul(gQ);
    } else {
      gA = -Q.mH().matmul(gQ);
    }
  } else {
    gA = gR.matmul(R.mH());
  }
  if (m >= n) {
    // syminvadj(X) = X + X^H with the real part of the diagonal halved
    const auto syminvadj = [](const Tensor& X) {
      auto ret = X + X.mH();
      at::real(ret.diagonal(0, -2, -1)).mul_(0.5);
      return ret;
    };
    gA = Q.matmul(syminvadj(gA.triu()));
    if (gQ.defined()) {
      gA = gA + gQ;
    }
    // Right-multiply by R^{-H}: solve X R^H = gA (R^H is lower triangular)
    gA = at::linalg_solve_triangular(
        R.mH(), gA, /*upper*/ false, /*left*/ false);
    return gA;
  } else {
    // trilImInv*(X) applied to X: tril(X - X^H) with, for complex inputs, the
    // imaginary part of the diagonal halved
    auto trilImInvAdjSkew = [](const Tensor& X) {
      auto ret = (X - X.mH()).tril();
      if (X.is_complex()) {
        at::imag(ret.diagonal(0, -2, -1)).mul_(0.5);
      }
      return ret;
    };
    // -gA = Q^H gQ - gR R^H, as in the formula above
    gA = Q.matmul(trilImInvAdjSkew(-gA));
    // Right-multiply by R_1^{-H}, where R_1 is the leading m x m block of R
    gA = at::linalg_solve_triangular(
        R.narrow_symint(-1, 0, m).mH(), gA, /*upper*/ false, /*left*/ false);
    // pi*: pad with zero columns up to the n columns of A
    auto shape = R.sym_sizes().vec();
    shape.end()[-1] = n - m;
    gA = at::cat({gA, gA.new_zeros_symint(shape)}, /*dim=*/-1);
    if (gR.defined()) {
      gA = gA + Q.matmul(gR);
    }
    return gA;
  }
}

// Based on:
//
// Mathias, Roy.
// A Chain Rule for Matrix Functions and Applications.
// SIAM J. Matrix Anal. Appl. 17 (1996): 610-620.
//
// Given an analytic matrix function f, computes its differential at `self`
// applied to `grad` (forward AD, adjoint=false) or the adjoint of that
// differential (backward AD, adjoint=true). f is evaluated on the 2n x 2n
// block matrix
//   [[A, grad],
//    [0,    A]]
// with A = self (forward) or A = self^H (backward), and the top-right n x n
// block of the result is returned.
template <typename func_t>
Tensor differential_analytic_matrix_function(
    const Tensor& self,
    const Tensor& grad,
    const func_t& matrix_function,
    const bool adjoint // Choose between forward (adjoint=false) or backward AD
                       // (adjoint=true)
) {
  auto A = adjoint ? self.transpose(-2, -1).conj() : self;
  const auto n = A.sym_size(-1);

  Tensor block;
  if (!areAnyTensorSubclassLike({A, grad})) {
    // Allocate the zero block matrix once, then write A into the two diagonal
    // blocks and grad into the top-right block.
    auto block_sizes = A.sym_sizes().vec();
    block_sizes[A.dim() - 2] *= 2;
    block_sizes[A.dim() - 1] *= 2;
    block = at::zeros_symint(block_sizes, grad.options());
    const auto top_rows = block.narrow_symint(-2, 0, n);
    const auto bottom_rows = block.narrow_symint(-2, n, n);
    top_rows.narrow_symint(-1, 0, n).copy_(A);
    bottom_rows.narrow_symint(-1, n, n).copy_(A);
    top_rows.narrow_symint(-1, n, n).copy_(grad);
  } else {
    // For Composite Compliance a subclass cannot be copied into a regular
    // tensor, so the block matrix is assembled out-of-place instead.
    // `new_zeros` is avoided on purpose: A and grad may both be subclasses
    // (e.g. BatchedTensors at different levels) and neither is a safe choice
    // for creating the output buffer.
    auto top_rows = at::cat({A, grad}, -1);
    auto bottom_rows = at::cat({at::zeros_like(A), std::move(A)}, -1);
    block = at::cat({top_rows, bottom_rows}, -2);
  }

  return matrix_function(block).narrow_symint(-2, 0, n).narrow_symint(
      -1, n, n);
}

// Differential of the matrix exponential at `self` applied to `grad`
// (adjoint=false), or its adjoint (adjoint=true, used for backward AD).
Tensor linalg_matrix_exp_differential(
    const Tensor& self,
    const Tensor& grad,
    bool adjoint) {
  // Matmuls inside the matrix exponential run without TF32.
  at::NoTF32Guard disable_tf32;
  return differential_analytic_matrix_function(
      self, grad, at::linalg_matrix_exp, adjoint);
}

template <typename F1, typename F2, typename... Ts>
Tensor masked_fmap(
    const Tensor& mask,
    const F1& f1,
    const F2& f2,
    const Tensor& t,
    const Ts&... ts) {
  // Builds a tensor shaped like `t` by applying f1 to the entries of
  // (t, ts...) selected by `mask` and f2 to the remaining ones. Used when one
  // formula is valid for, e.g., non-singular inputs and another one for
  // singular inputs (see linalg_det_backward).

  // With t empty, the "mask selects nothing" fast path below would be ambiguous.
  TORCH_INTERNAL_ASSERT(t.sym_numel() != 0);
  const auto selected = t.index({mask});
  const auto num_selected = selected.sym_numel();

  // Fast paths: the mask selects everything, or nothing.
  if (num_selected == t.sym_numel()) {
    return f1(t, ts...);
  }
  if (num_selected == 0) {
    return f2(t, ts...);
  }

  // Mixed case, equivalent to
  //   out = torch.empty_like(t)
  //   out[mask] = f1(t[mask], ts[mask]...)
  //   out[~mask] = f2(t[~mask], ts[~mask]...)
  const auto rejected = mask.logical_not();
  auto out = at::empty_like(t);
  out.index_put_({mask}, f1(selected, ts.index({mask})...));
  out.index_put_(
      {rejected}, f2(t.index({rejected}), ts.index({rejected})...));
  return out;
}

Tensor linalg_det_jvp(
    const Tensor& dA,
    const Tensor& det,
    const Tensor& LU,
    const Tensor& pivots,
    const bool use_A_T) {
  // Jacobi's formula: (d det)_A(E) = tr(A^{-1} E) det(A).
  // det is C^1, so singular inputs are approximated by nudging exact zeros on
  // the diagonal of the LU factor to eps. Forward AD is never differentiated
  // further, so unlike linalg_det_backward no higher-order-aware formula is
  // needed here.
  const auto eps =
      at::native::_get_epsilon(c10::toRealValueType(LU.scalar_type()));
  const auto U_diag = LU.diagonal(0, -2, -1);
  const auto nudge = at::where(U_diag == 0., eps, 0.);
  const auto LU_nudged = LU + at::diag_embed(nudge);
  const auto AinvE = at::linalg_lu_solve(
      LU_nudged, pivots, dA, /*left=*/true, /*adjoint=*/use_A_T);
  const auto trace = AinvE.diagonal(0, -2, -1).sum(-1);
  return trace * det;
}

Tensor linalg_det_backward(
    const Tensor& grad,
    const Tensor& det,
    const Tensor& A,
    const Tensor& LU,
    const Tensor& pivots) {
  at::NoTF32Guard disable_tf32;
  // Backward of det(A). Returns gA, or an undefined tensor when grad is
  // undefined or A is empty.
  //
  // A.numel() == 0 necessary for the singular case
  // (masked_fmap asserts a non-empty input)
  if (!grad.defined() || A.sym_numel() == 0) {
    return {};
  }

  // The gradient G is the matrix solving
  // A.mH G = det(A).conj() * grad * I
  auto d_diag = grad * det.conj();
  // Optimisation, Make it F-transposed as it's what lu_solve expects
  auto d = at::diag_embed(d_diag.unsqueeze(-1).expand_as(pivots)).mT();
  // Machine epsilon of the real dtype; used both to perturb singular LU
  // factors and to pick the singular formula below.
  auto eps = at::native::_get_epsilon(c10::toRealValueType(LU.scalar_type()));

  // Optimisation if we are not going to compute higher-order gradients
  if (!at::GradMode::is_enabled()) {
    // When A is invertible, the gradient is the solution of
    // A^H X = det.conj() * grad * I. det is C^1, so if A is not invertible we
    // perturb the exact zeros on the diagonal of the LU factor by eps and use
    // the resulting matrix as a non-singular approximation.
    auto LU_ =
        LU + at::diag_embed(at::where(LU.diagonal(0, -2, -1) == 0., eps, 0.));
    // NOTE(review): when use_A_T is true, LU presumably holds the
    // factorization of A^T (= A^H for real A), so no adjoint solve is
    // needed; confirm against where LU is produced in the forward.
    auto use_A_T = A.is_contiguous() && !A.is_complex();
    return at::linalg_lu_solve(
        LU_, pivots, d, /*left=*/true, /*adjoint=*/!use_A_T);
  } else {
    // If we want to compute higher-order gradients, we need to recompute the
    // LU decomposition so that autograd computes the correct gradients wrt
    // to A (cf. solve_backward)
    auto non_singular =
        [](const Tensor& A, const Tensor& d, const Tensor& /*grad*/) {
          return at::linalg_solve(A.mH(), d);
        };

    // The derivative may be then computed explicitly by noting that the
    // gradient of the derivative of the determinant is given in terms of the
    // adjugate of a matrix. The adjugate of a singular matrix may be computed
    // as per https://nhigham.com/2020/06/16/what-is-the-adjugate-of-a-matrix/
    auto singular = [](const Tensor& A,
                       const Tensor& /*d*/,
                       const Tensor& grad) {
      auto [U, S, Vh] = at::linalg_svd(A);
      auto alpha = (at::linalg_det(U) * at::linalg_det(Vh)).conj() * grad;
      // Presumably D_i = alpha * prod_{j != i} S_j (the gradient of prod(S),
      // computed without dividing by S); see prod_safe_zeros_backward.
      auto D = prod_safe_zeros_backward(alpha.unsqueeze(-1), S, S.dim() - 1);
      return (U * D.unsqueeze(-2)).matmul(Vh);
    };

    // We could use the singular formula for all inputs but we try to filter out
    // some inputs via the masking, as computing an SVD is about 100 times
    // slower than computing an lu_solve on GPU
    // For tensor subclasses, we can't call masked_fmap as it calls
    // index({mask}) which needs to call item to compute the number of elements
    // in the result.

    if (areAnyTensorSubclassLike({A, d, grad})) {
      return singular(A, d, grad);
    } else {
      // Matrices with |det| < 100 * eps are treated as singular.
      return masked_fmap(
          det.abs() < 100. * eps, singular, non_singular, A, d, grad);
    }
  }
}

std::tuple<Tensor, Tensor> slogdet_jvp(
    const Tensor& LU,
    const Tensor& pivots,
    const Tensor& dA,
    const Tensor& sign,
    const bool use_A_T) {
  // Forward AD of (sign, logabsdet) = slogdet(A), with t = tr(A^{-1} dA):
  //   d(logabsdet) = Re(t)
  //   d(sign)      = sign * i * Im(t)   (zero for real inputs)
  // Singular matrices need no special treatment (unlike det), since slogdet
  // is not differentiable there.
  const auto AinvE =
      at::linalg_lu_solve(LU, pivots, dA, /*left*/ true, use_A_T);
  const auto t = AinvE.diagonal(0, -2, -1).sum(-1);
  if (!LU.is_complex()) {
    // The sign of a real determinant is locally constant.
    return std::make_tuple(
        at::_efficientzerotensor(sign.sizes(), sign.options()), t);
  }
  const auto i = c10::complex<double>{0.0, 1.0};
  return std::make_tuple(at::imag(t) * (i * sign), at::real(t));
}

Tensor slogdet_backward(
    const Tensor& grad_sign,
    const Tensor& grad_logabsdet,
    const Tensor& A,
    const Tensor& signdet,
    const Tensor& LU,
    const Tensor& pivots) {
  // Backward of (sign, logabsdet) = slogdet(A). Derived for complex A; the
  // real case follows from it.
  // Forward AD:
  //   d(logabsdet)_A(E) = Re(tr(A^{-1}E))
  //   d(signdet)_A(E)   = sgn * i * Im(tr(A^{-1}E))
  // Adjoints wrt. the real inner product <X, Y> = Re tr(X^H Y):
  //   d(logabsdet)*_A(g) = g A^{-H}
  //   d(signdet)*_A(g)   = -i Im(g.conj() * sgn) A^{-H}
  // The latter uses Re(z Im(w)) = Re(-Re(z) i w) with z = g.conj() * sgn * i.
  // Hence
  //   gA = (g_abs - i Im(g_sign.conj() * sgn)) A^{-H}
  const bool has_gsign = grad_sign.defined();
  const bool has_glogabs = grad_logabsdet.defined();
  if (!has_gsign && !has_glogabs) {
    return {};
  }

  const bool is_complex = A.is_complex();
  // For real inputs the sign has zero derivative, so only grad_logabsdet
  // contributes.
  if (!is_complex && !has_glogabs) {
    return {};
  }

  // g is the scalar coefficient such that gA = g A^{-H}.
  Tensor g;
  if (!is_complex) {
    g = grad_logabsdet;
  } else if (!has_gsign) {
    // Cast to complex explicitly
    g = grad_logabsdet.to(A.scalar_type());
  } else {
    const auto i = c10::complex<double>{0.0, 1.0};
    const auto im_term = at::imag(grad_sign.conj() * signdet);
    g = has_glogabs ? grad_logabsdet - i * im_term : -i * im_term;
  }

  // d = g * I, stored F-transposed because that is the layout lu_solve
  // expects. No singular-input handling is needed, since slogdet is not
  // differentiable on singular matrices.
  const auto d = at::diag_embed(g.unsqueeze(-1).expand_as(pivots)).mT();
  if (at::GradMode::is_enabled()) {
    // Higher-order gradients need the LU decomposition recomputed under
    // autograd so the gradients wrt A are correct (cf. solve_backward).
    return at::linalg_solve(A.mH(), d);
  }
  const auto use_A_T = A.is_contiguous() && !A.is_complex();
  return at::linalg_lu_solve(
      LU, pivots, d, /*left=*/true, /*adjoint=*/!use_A_T);
}

// Reference:
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Sec. 2.3.1 Matrix inverse product
std::tuple<Tensor, Tensor> triangular_solve_backward(
    const Tensor& grad_x,
    const Tensor& grad_m,
    const Tensor& b,
    const Tensor& a,
    const Tensor& x,
    const bool upper,
    const bool transpose,
    const bool unitriangular,
    std::array<bool, 2> output_mask) {
  at::NoTF32Guard disable_tf32;
  Tensor grad_b, grad_a;
  if (grad_x.defined() || grad_m.defined()) {
    if (grad_x.defined()) {
      grad_b = std::get<0>(
          grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
      if (output_mask[1]) {
        grad_a =
            transpose ? -x.conj().matmul(grad_b.mT()) : -grad_b.matmul(x.mH());
        if (upper) {
          grad_a = grad_a.triu((int)unitriangular);
        } else {
          grad_a = grad_a.tril(-((int)unitriangular));
        }
      }
    }
    if (!grad_a.defined()) {
      grad_a = at::zeros({1}, a.options()).expand_as(a);
    }
    if (!grad_b.defined()) {
      grad_b = at::zeros({1}, b.options()).expand_as(b);
    }
    if (output_mask[1] && grad_m.defined()) {
      grad_a = grad_a.add(grad_m);
    }
  }
  return std::tuple<Tensor, Tensor>{grad_b, grad_a};
}

// Forward-mode derivative of triangular_solve.
// The shared solve-JVP logic lives in generic_solve_jvp (defined elsewhere in
// this file). This function only provides the solver callback: given the
// coefficient matrix, the right-hand-side tangent and the dA contribution, it
// solves op(A) Y = rhs_tangent - dA_contrib with the forward's triangular
// options.
Tensor triangular_solve_jvp(
    const Tensor& X,
    const Tensor& A,
    const Tensor& dA,
    const Tensor& dB,
    const bool upper,
    const bool transpose,
    const bool unitriangular) {
  auto triangular_solver = [&](const Tensor& coeff,
                               const Tensor& rhs_tangent,
                               const Tensor& dA_contrib) {
    const Tensor rhs = rhs_tangent - dA_contrib;
    return std::get<0>(
        at::triangular_solve(rhs, coeff, upper, transpose, unitriangular));
  };
  return generic_solve_jvp(triangular_solver, X, A, dA, dB);
}

// Forward-mode derivative of linalg_solve_triangular.
// For X = A^{-1} B (left = true) the tangent is A^{-1} (B_t - A_t X); for
// left = false it is (B_t - X A_t) A^{-1}. The derivation is in
// [Note: Forward / Backward AD solve_triangular].
// Only the triangle of A that the solver reads is a valid tangent direction,
// so A_t is projected onto it first. The diagonal is excluded when
// `unitriangular` is set.
Tensor linalg_solve_triangular_forward_AD(
    const Tensor& A_t,
    const Tensor& B_t,
    const Tensor& A,
    const Tensor& X,
    const bool upper,
    const bool left,
    const bool unitriangular) {
  at::NoTF32Guard disable_tf32;
  const int diag = static_cast<int>(unitriangular);
  const Tensor A_t_tri = upper ? A_t.triu(diag) : A_t.tril(-diag);
  const Tensor A_t_X = left ? at::matmul(A_t_tri, X) : at::matmul(X, A_t_tri);
  const Tensor rhs = B_t - A_t_X;
  return at::linalg_solve_triangular(A, rhs, upper, left, unitriangular);
}

// Backward of linalg_solve_triangular. Returns {G_A, G_B}; an entry is left
// undefined when its input does not require grad (per `output_mask`) or when
// `grad` is undefined.
std::tuple<Tensor, Tensor> linalg_solve_triangular_backward(
    const Tensor& grad,
    const Tensor& A,
    const Tensor& X,
    const bool upper,
    const bool left,
    const bool unitriangular,
    std::array<bool, 2> output_mask) {
  at::NoTF32Guard disable_tf32;
  const bool A_requires_grad = output_mask[0];
  const bool B_requires_grad = output_mask[1];
  // [Note: Forward / Backward AD solve_triangular]
  // Assume left=true for simplicity; a solver computes A^{-1}B.
  //
  // Forward AD:
  // If f(A) = A^{-1}, differentiating A^{-1}A = I_n gives
  // (df)_A(E) = -A^{-1}EA^{-1}
  // so for g(A,B) = A^{-1}B,
  // (dg)_(A,B)(E_A, E_B) = -A^{-1}E_AA^{-1}B + A^{-1}E_B
  //                      = A^{-1}(E_B - E_AX)
  //
  // Backward AD:
  // Taking adjoints with gradients G_A, G_B gives
  // G_B = A^{-H}G_X
  // G_A = -A^{-H}G_XX^H = -G_B X^H
  // (for left=false: G_B = G_X A^{-H}, G_A = -X^H G_B).
  //
  // Neither direction needs B itself.
  //
  // These formulas hold for any linear solver. When A is triangular, G_A is
  // the projection of the general formula onto triangular matrices, i.e. its
  // triu (resp. tril). The triangular matrices form a vector space, so the
  // tangent space at any point is that same space. The argument is the same
  // as at the end of [Note: eigh backward]. For `unitriangular`, the tangent
  // space is the set of triangular matrices with zeros on the diagonal.
  if (!grad.defined() || !(A_requires_grad || B_requires_grad)) {
    return std::make_tuple(Tensor{}, Tensor{});
  }

  // G_B is always needed: it is also an intermediate of G_A.
  // A^H is triangular in the opposite orientation, hence `!upper`.
  const Tensor G_B =
      at::linalg_solve_triangular(A.mH(), grad, !upper, left, unitriangular);
  if (!A_requires_grad) {
    return std::make_tuple(Tensor{}, G_B);
  }

  const Tensor X_H = X.mH();
  const Tensor G_A_dense =
      left ? -at::matmul(G_B, X_H) : -at::matmul(X_H, G_B);
  const int diag = static_cast<int>(unitriangular);
  Tensor G_A = upper ? G_A_dense.triu(diag) : G_A_dense.tril(-diag);
  return std::make_tuple(G_A, B_requires_grad ? G_B : Tensor{});
}

std::tuple<Tensor, Tensor> cholesky_solve_backward(
    const Tensor& grad_x,
    const Tensor& self,
    const Tensor& input2,
    const Tensor& result,
    const bool upper,
    std::array<bool, 2> output_mask) {
  at::NoTF32Guard disable_tf32;
  Tensor grad_self, grad_input2;
  if (grad_x.defined()) {
    grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper);

    if (output_mask[1]) {
      Tensor common_term = at::matmul(grad_self, result.mH());
      common_term = common_term + common_term.mH();

      if (upper) {
        grad_input2 = -at::matmul(input2, common_term);
      } else {
        grad_input2 = -at::matmul(common_term, input2);
      }
    }
  }
  return std::tuple<Tensor, Tensor>{grad_self, grad_input2};
}

// Forward-mode derivative of cholesky_solve with factor U.
// The implied matrix is U^H U (upper) or U U^H (lower), so its tangent is
// dK + dK^H with dK = dU^H U, resp. dU U^H. generic_solve_jvp (defined
// elsewhere in this file) is given the factor U as "A" together with this
// tangent.
Tensor cholesky_solve_jvp(
    const Tensor& X,
    const Tensor& U,
    const Tensor& dU,
    const Tensor& dB,
    const bool upper) {
  at::NoTF32Guard disable_tf32;
  const Tensor dK = upper ? dU.mH().matmul(U) : dU.matmul(U.mH());
  const Tensor dA = dK + dK.mH();
  auto cholesky_solver = [&](const Tensor& factor,
                             const Tensor& rhs_tangent,
                             const Tensor& dA_contrib) {
    return at::cholesky_solve(rhs_tangent - dA_contrib, factor, upper);
  };
  return generic_solve_jvp(cholesky_solver, X, /*A=*/U, dA, dB);
}

// Backward of a onesided C2R transform (irfft-style).
// The forward can be viewed as
//    1. filling in the missing half of the spectrum by conjugate symmetry,
//    2. an inverse C2C transform,
//    3. dropping the imaginary part.
// The backward is therefore
//    1. an R2C transform (rfft) of grad, and
//    2. accumulating, onto each onesided entry, the gradient of its mirrored
//       entry that was synthesized in step 1 of the forward.
// Let N be the signal length along the last transformed dim. Onesided index
// idx is mirrored to (N - idx) % N:
//    i.   idx = 0 -> 0. Not doubled.
//    ii.  0 < idx < floor(N/2) -> N - idx, which lies in (ceil(N/2), N) and
//         so outside the onesided range. Doubled.
//    iii. idx = N/2 with N even -> N/2. Not doubled.
//    iv.  idx = (N-1)/2 with N odd -> (N+1)/2, outside the range. Doubled.
// The doubled entries are therefore
//    idx = 1, 2, ..., N - (floor(N/2) + 1) = 1, ..., N - onesided_length.
Tensor fft_c2r_backward(
    const Tensor& grad,
    IntArrayRef dim,
    int64_t normalization) {
  Tensor gI = at::_fft_r2c(grad, dim, normalization, /*onesided=*/true);

  const int64_t last_dim = dim.back();
  const auto double_length = grad.sym_size(last_dim) - gI.sym_size(last_dim);
  // A non-positive length (including a zero-sized signal) means no entries
  // need doubling.
  if (double_length > 0) {
    gI.narrow_symint(last_dim, 1, double_length).mul_(2);
  }
  return gI;
}

// Backward of an R2C transform.
// Twosided: the forward is a C2C transform of a real signal, so the backward is
// the inverse C2C transform with the imaginary part dropped.
// Onesided: the forward is (view as complex) -> C2C -> drop half the output.
// The backward zero-pads grad back to `last_dim_size` along the last
// transformed dim, because the C2C inverse needs twosided input. It then
// applies the inverse C2C transform and drops the imaginary part.
Tensor fft_r2c_backward(
    const Tensor& grad,
    at::IntArrayRef dim,
    int64_t normalization,
    bool onesided,
    const c10::SymInt& last_dim_size) {
  if (!onesided) {
    return at::real(at::_fft_c2c(grad, dim, normalization, /*forward=*/false));
  }

  const auto onesided_sizes = grad.sym_sizes();
  const auto wrapped_last = at::maybe_wrap_dim(
      dim.back(), static_cast<int64_t>(onesided_sizes.size()));
  const auto pad_length = last_dim_size - grad.sym_size(dim.back());

  Tensor full_grad = grad;
  if (pad_length > 0) {
    std::vector<c10::SymInt> full_shape(
        onesided_sizes.begin(), onesided_sizes.end());
    full_shape[wrapped_last] = last_dim_size;
    full_grad = grad.new_zeros_symint(full_shape);
    // Leading part holds the onesided gradient; the tail stays zero.
    full_grad.slice_symint(wrapped_last, 0, onesided_sizes[wrapped_last])
        .copy_(grad);
  }
  return at::real(
      at::_fft_c2c(full_grad, dim, normalization, /*forward=*/false));
}

// Helper for batchnorm_double_backward
// Sums `to_sum` over every dimension except dim 1 (the channel dim).
// keepdim = true keeps the reduced dims as size 1; otherwise only the channel
// dim survives.
static Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim = true) {
  Tensor reduced = to_sum.sum(0, keepdim);
  // Position of the channel dim after the batch dim has been reduced.
  const int64_t channel_pos = keepdim ? 1 : 0;
  int64_t d = reduced.dim() - 1;
  while (d > channel_pos) {
    reduced = reduced.sum(d, keepdim);
    --d;
  }
  return reduced;
}

// Helper for batchnorm_double_backward
// Reshapes a per-channel tensor to the rank of `target` without expanding,
// i.e. to the shape it would have after a keepdim=True reduction over every
// dim but 1. Trailing singleton dims are inserted at position 1, then a
// leading one is added. Like expand_as_dim1 below, but without the final
// expand_as.
static Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
  const auto rank_without_batch = target.sizes().size() - 1;
  Tensor reshaped = src;
  for (; reshaped.sizes().size() < rank_without_batch;) {
    reshaped = reshaped.unsqueeze(1);
  }
  if (reshaped.sizes().size() == rank_without_batch) {
    reshaped = reshaped.unsqueeze(0);
  }
  return reshaped;
}

// Helper for batchnorm_double_backward
// gamma/ggG/ggB are 1-dimensional and describe dim 1, so expanding them
// directly would follow the wrong broadcasting rules. Singleton dims are
// first inserted at position 1 until the tensor has rank(target) - 1; the
// result is then broadcast to `target`.
static Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
  const auto rank_without_batch = target.sizes().size() - 1;
  Tensor reshaped = src;
  for (; reshaped.sizes().size() < rank_without_batch;) {
    reshaped = reshaped.unsqueeze(1);
  }
  return reshaped.expand_as(target);
}

// Double backward of batch norm.
// ggI / ggG / ggB are presumably the gradients flowing into the first
// backward's grad_input / grad_weight / grad_bias outputs. Returns
// {gI, gG, ggO}: the gradients w.r.t. `input`, `gamma` and the first
// backward's incoming gradient `gO`. An entry is undefined when nothing
// contributes to it.
// Training mode uses the saved batch statistics (save_mean / save_invstd).
// Eval mode uses running_mean / running_var and treats them as constants.
// Per-channel quantities live on dim 1 of `input`.
std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
    const Tensor& input,
    const std::optional<Tensor>& gamma,
    const Tensor& ggI,
    const Tensor& ggG,
    const Tensor& ggB,
    const Tensor& gO,
    const std::optional<Tensor>& running_mean,
    const std::optional<Tensor>& running_var,
    bool training,
    double eps,
    const std::optional<Tensor>& save_mean,
    const std::optional<Tensor>& save_invstd,
    std::array<bool, 3> output_mask) {
  bool affine = isDefined(gamma);
  // TODO: Do we have a ScalarOrTensor type?  Would such a thing exist?
  Tensor gamma_expanded;
  Tensor ggG_expanded, ggB_expanded;
  if (affine) {
    // Broadcast the 1-d per-channel tensors to the full input shape.
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    gamma_expanded = expand_as_dim1(*gamma, input);
    if (ggG.defined()) {
      ggG_expanded = expand_as_dim1(ggG, input);
    }
    if (ggB.defined()) {
      ggB_expanded = expand_as_dim1(ggB, input);
    }
  } else {
    // Without an affine weight, gamma behaves as the constant 1.
    gamma_expanded = at::ones({}, input.options());
  }

  // define some terms we will reuse
  // M = elements reduced per channel: the batch size times the product of
  // all dims after dim 1.
  auto M = input.size(0);
  for (auto s : input.sizes().slice(2)) {
    M *= s;
  }
  // for half inputs, save_mean, save_invstd are float (ideally, we would cast
  // everything else, but not now)
  auto mu = unsqueeze_dim1(
      training ? toNonOptTensor(save_mean).to(input.scalar_type())
               : toNonOptTensor(running_mean),
      input);
  auto input_sub_mu = input - mu;
  // sigma2_eps_neg_1_2 = (var + eps)^{-1/2}; powers 1 and 3/2 are derived.
  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(
      training ? toNonOptTensor(save_invstd).to(input.scalar_type())
               : toNonOptTensor(running_var).add(Scalar(eps)).pow(-0.5),
      input);
  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

  // calculate gI
  auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
  // Per-channel reductions of gO shared by the training-mode terms below.
  // They are never modified in place, so they can be reused safely.
  auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu);
  auto gO_sum = sum_exclude_dim1(gO);

  Tensor gI;
  if (ggI.defined() && training) {
    auto ggI_sum = sum_exclude_dim1(ggI);
    auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
    auto all_sub = ((ggI_sum * gO_sum).div_(M))
                       .sub_(sum_exclude_dim1(gO * ggI))
                       .add_((sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum)
                                 .mul_(3. / static_cast<double>(M)));
    auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
    auto gI_1t =
        (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
    auto gI_2t =
        (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI);
    gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t));
  }

  // add contribution of gamma term to gI
  Tensor gI_G_term;
  if (affine && ggG.defined()) {
    if (training) {
      auto t0 = gO * sigma2_eps_neg_1_2;
      auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M);
      // Reuse gOinmu_sum instead of recomputing sum_exclude_dim1(gO *
      // input_sub_mu), which is the same reduction.
      auto t2 = (input_mu_sigma2_neg_3_2 * gOinmu_sum).div_(-M);
      gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2));
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    } else {
      // Eval mode: statistics are constants, so only the direct term remains.
      gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO;
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    }
  }

  // this is the first backward's grad_input
  auto first_back_grad_input = [&](const Tensor& gO,
                                   const Tensor& gamma) -> Tensor {
    auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M);
    auto h1 = (M * gO)
                  .sub_(sum_exclude_dim1(gO))
                  .sub_(
                      input_sub_mu.mul(sigma2_eps_neg_1) *
                      sum_exclude_dim1(gO * input_sub_mu));
    return h0 * h1;
  };

  // calculate gG
  Tensor gG;
  if (affine && ggI.defined()) {
    if (training) {
      // gG is just the first backwards with the gamma term removed (then shaped
      // properly)
      gG = ggI *
          first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options()));
      gG = sum_exclude_dim1(gG, false);
    } else {
      gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false);
    }
  }

  // calculate ggO
  Tensor ggO;
  // contribution of input term
  if (ggI.defined()) {
    if (training) {
      ggO = first_back_grad_input(ggI, gamma_expanded);
    } else {
      ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded;
    }
  }
  if (ggG.defined()) {
    auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
    ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
  }
  if (ggB.defined()) {
    auto ggO_B_term = std::move(ggB_expanded);
    ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
  }

  if (output_mask[1] && !gG.defined()) {
    AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
  }

  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}

// Double backward of layer norm.
// The input is flattened to an (M, N) matrix. Each of the M rows (the leading,
// non-normalized dims) is normalized independently over its N elements (the
// trailing `normalized_shape` dims). gamma / ggG / ggB are viewed as (1, N)
// rows and the saved statistics as (M, 1) columns.
// Returns {gI, gG, ggO}: the gradients w.r.t. the input, gamma and the first
// backward's incoming gradient gO, reshaped back to the original shapes. An
// entry is undefined when nothing contributes to it.
std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
    const Tensor& input_t,
    const std::optional<Tensor>& gamma,
    const Tensor& ggI,
    const Tensor& ggG,
    const Tensor& ggB,
    const Tensor& gO_t,
    const Tensor& save_mean_t,
    const Tensor& save_invstd_t,
    c10::SymIntArrayRef normalized_shape,
    std::array<bool, 3> output_mask) {
  const auto normalized_ndim = normalized_shape.size();
  const auto input_shape = input_t.sizes();
  const auto input_ndim = input_t.dim();
  const auto axis = input_ndim - normalized_ndim;
  const int64_t M =
      c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
  const int64_t N =
      c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());
  // M = number of independently normalized rows, N = elements per row.

  auto input = input_t.reshape({M, N});
  auto gO = gO_t.reshape({M, N});
  auto save_mean = save_mean_t.reshape({M, 1});
  auto save_invstd = save_invstd_t.reshape({M, 1});

  bool affine = isDefined(gamma);
  Tensor gamma_expanded;
  Tensor ggG_expanded, ggB_expanded;
  if (affine) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    gamma_expanded = gamma->reshape({1, N});
    if (ggG.defined()) {
      ggG_expanded = ggG.reshape({1, N});
    }
    if (ggB.defined()) {
      ggB_expanded = ggB.reshape({1, N});
    }
  } else {
    // Without an affine weight, gamma behaves as the constant 1.
    gamma_expanded = at::ones({1}, input.options());
  }

  Tensor ggI_expanded;
  if (ggI.defined()) {
    ggI_expanded = ggI.reshape({M, N});
  }

  // for half inputs, save_mean, save_invstd are float
  // (ideally, we would cast everything else, but not now)
  auto mu = save_mean.to(input.scalar_type());
  auto input_sub_mu = input - mu;
  // sigma2_eps_neg_1_2 is the saved inverse std; powers 1 and 3/2 derived.
  auto sigma2_eps_neg_1_2 = save_invstd.to(input.scalar_type());
  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

  Tensor gI;
  // calculate gI
  auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;

  if (ggI.defined()) {
    // gxhat: gradient w.r.t. the normalized input (gO scaled by gamma).
    auto gxhat = gO * gamma_expanded;
    auto gxhat_mu_sum = (gxhat * input_sub_mu).sum(1, true);
    auto gxhat_sum = gxhat.sum(1, true);

    auto ggI_sum = ggI_expanded.sum(1, true);
    auto ggI_mu_sum = (ggI_expanded * input_sub_mu).sum(1, true);

    auto all_sub = ((ggI_sum * gxhat_sum).div_(N))
                       .sub_((ggI_expanded * gxhat).sum(1, true))
                       .add_((sigma2_eps_neg_1 * gxhat_mu_sum * ggI_mu_sum)
                                 .mul_(3. / static_cast<double>(N)));
    auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(N);
    auto gI_1t =
        (ggI_mu_sum * sigma2_eps_neg_3_2).div_(N) * (gxhat_sum.div(N) - gxhat);
    auto gI_2t = (gxhat_mu_sum * sigma2_eps_neg_3_2).div_(N) *
        (ggI_sum.div(N) - ggI_expanded);

    gI = (gI_0t.add_(gI_1t).add_(gI_2t));
  }

  // add contribution of gamma term to gI
  if (affine && ggG.defined()) {
    auto t0 = gO * ggG_expanded * sigma2_eps_neg_1_2;
    auto t1 = (sigma2_eps_neg_1_2 * (gO * ggG_expanded).sum(1, true)).div_(-N);
    auto t2 = (input_mu_sigma2_neg_3_2 *
               (gO * ggG_expanded * input_sub_mu).sum(1, true))
                  .div_(-N);
    auto gI_G_term = t0.add_(t1).add_(t2);
    gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
  }

  if (gI.defined()) {
    // Undo the (M, N) flattening.
    gI = gI.reshape_as(input_t);
  }

  // this is the grad_input for the first backward function
  auto first_bwd_fn_grad_input = [&](const Tensor& gO_local,
                                     const Tensor& gamma_local) -> Tensor {
    auto h0 = (gamma_local * sigma2_eps_neg_1_2).div_(N);
    auto h1 = (N * gO_local)
                  .sub_(gO_local.sum(1, true))
                  .sub_(
                      input_sub_mu.mul(sigma2_eps_neg_1) *
                      (gO_local * input_sub_mu).sum(1, true));
    return h0 * h1;
  };

  // calculate gG
  Tensor gG;
  if (affine && ggI.defined()) {
    // First-backward grad_input of ggI with gamma replaced by 1, contracted
    // with gO over the rows.
    gG = first_bwd_fn_grad_input(
        ggI_expanded, at::ones({}, sigma2_eps_neg_1_2.options()));
    gG = (gO * gG).sum(0);
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    gG = gG.reshape_as(*gamma);
  }

  // calculate ggO
  Tensor ggO;
  // contribution of input term
  if (ggI.defined()) {
    ggO = first_bwd_fn_grad_input(ggI_expanded, gamma_expanded);
  }
  if (ggG.defined()) {
    auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
    ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
  }
  if (ggB.defined()) {
    auto ggO_B_term = std::move(ggB_expanded);
    ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
  }
  if (ggO.defined()) {
    // A lone ggB term is only (1, N); broadcast it before restoring the shape.
    ggO = ggO.expand({M, N}).reshape_as(input_t);
  }

  if (output_mask[1] && !gG.defined()) {
    AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
  }

  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}

// Backward of native_group_norm, written entirely with differentiable tensor
// ops so the result can itself be differentiated.
// X is viewed as (N, G, D, HxW) with D = C / G channels per group.
// mean / rstd / dmean / drstd hold one value per (sample, group).
// Returns {dX, dgamma, dbeta}. An entry is undefined when its bit in
// `grad_input_mask` is off or no incoming gradient contributes to it.
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
    const Tensor& dY,
    const Tensor& dmean,
    const Tensor& drstd,
    const Tensor& X,
    const Tensor& mean,
    const Tensor& rstd,
    const std::optional<Tensor>& gamma,
    c10::SymInt N,
    const c10::SymInt& C,
    c10::SymInt HxW,
    int64_t group,
    double eps,
    std::array<bool, 3> grad_input_mask) {
  const int64_t G = group;
  const auto D = C / G;
  // 1 / (elements per group)
  c10::SymFloat s = c10::SymFloat(1.0) / c10::SymFloat(D * HxW);
  Tensor dX;
  Tensor dgamma;
  Tensor dbeta;
  const Tensor X_tensor = X.reshape_symint({N, G, D, HxW});
  const Tensor mean_tensor = mean.reshape_symint({N, G, 1, 1});
  const Tensor rstd_tensor = rstd.reshape_symint({N, G, 1, 1});
  Tensor dY_tensor;
  // Per-channel sums over the spatial dim: ds = sum(dY * X), db = sum(dY).
  Tensor ds;
  Tensor db;
  if (dY.defined()) {
    dY_tensor = dY.reshape_symint({N, G, D, std::move(HxW)});
    ds = (dY_tensor * X_tensor).sum(3).unsqueeze_(-1);
    db = dY_tensor.sum(3).unsqueeze_(-1);
  }
  if (grad_input_mask[0]) {
    Tensor gamma_tensor;
    if (isDefined(gamma)) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      gamma_tensor = gamma->reshape_symint({1, G, D, 1});
    }
    // The variance is never needed explicitly. An unused
    // `(1 / rstd^2 - eps).clamp_min(0)` computation used to live here and has
    // been removed as dead tensor work.
    const Tensor rstd_cube = rstd_tensor * rstd_tensor * rstd_tensor;
    // Gradient w.r.t. the variance implied by drstd:
    // d rstd / d var = -1/2 * rstd^3.
    Tensor dvar;
    if (drstd.defined()) {
      dvar = -0.5 * rstd_cube * drstd.view_symint({N, G, 1, 1});
    }
    if (dY.defined()) {
      // dX = a * dY + b * X + c, with a per-channel `a` and per-group b, c.
      const Tensor a =
          isDefined(gamma) ? rstd_tensor * gamma_tensor : rstd_tensor;
      Tensor b = (isDefined(gamma) ? (ds * gamma_tensor).sum(2) : ds.sum(2))
                     .unsqueeze_(-2);
      Tensor c = (isDefined(gamma) ? (db * gamma_tensor).sum(2) : db.sum(2))
                     .unsqueeze_(-2);
      b = (c * mean_tensor - b) * rstd_cube * s;
      c = -b * mean_tensor - c * rstd_tensor * std::move(s);
      dX = a * dY_tensor + b * X_tensor + c;
      if (dmean.defined() && drstd.defined()) {
        dX += var_mean_backward(
            dvar,
            dmean.view_symint({std::move(N), G, 1, 1}),
            X_tensor,
            IntArrayRef{2, 3},
            0,
            true);
      }
      dX = dX.reshape_as(X);
    } else if (dmean.defined() && drstd.defined()) {
      dX = var_mean_backward(
               dvar,
               dmean.view_symint({std::move(N), G, 1, 1}),
               X_tensor,
               IntArrayRef{2, 3},
               0,
               true)
               .reshape_as(X);
    }
  }
  if (grad_input_mask[1] && dY.defined()) {
    dgamma = ((ds - db * mean_tensor) * rstd_tensor)
                 .sum(0)
                 .reshape_as(toNonOptTensor(gamma));
  }
  if (grad_input_mask[2] && dY.defined()) {
    dbeta = db.sum(0).reshape_as(toNonOptTensor(gamma));
  }

  return std::make_tuple(dX, dgamma, dbeta);
}

std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(
    const Tensor& grad_out,
    const std::optional<Tensor>& i1,
    const std::optional<Tensor>& i2,
    const std::optional<Tensor>& i3,
    IntArrayRef expand1,
    IntArrayRef expand2,
    IntArrayRef expand3,
    IntArrayRef sumdim,
    std::array<bool, 3> grad_mask) {
  Tensor grad_i1, grad_i2, grad_i3;
  if (grad_out.defined()) {
    if (grad_mask[0])
      grad_i1 =
          // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
          at::_trilinear(grad_out, *i2, *i3, sumdim, expand2, expand3, expand1);
    if (grad_mask[1])
      grad_i2 =
          // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
          at::_trilinear(*i1, grad_out, *i3, expand1, sumdim, expand3, expand2);
    if (grad_mask[2])
      grad_i3 =
          // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
          at::_trilinear(*i1, *i2, grad_out, expand1, expand2, sumdim, expand3);
  }
  return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
}

// Backward of log1p: grad / conj(self + 1).
// A sparse `self` is densified first, with a warning, because sparse addition
// requires an exact shape match. A sparse `grad` is multiplied by the
// reciprocal instead of divided, which preserves its sparse layout; gradcheck
// relies on that.
Tensor log1p_backward(const Tensor& grad, const Tensor& self) {
  // True for COO and every compressed sparse layout.
  const auto has_sparse_layout = [](const Tensor& t) {
    const auto layout = t.layout();
    return layout == c10::kSparse || layout == c10::kSparseCsr ||
        layout == c10::kSparseCsc || layout == c10::kSparseBsr ||
        layout == c10::kSparseBsc;
  };

  Tensor self_p1_conj;
  if (has_sparse_layout(self)) {
    // The warning concerns only the sparsity of self. A dense grad is never
    // materialized, so a strided self with a sparse grad has no hidden memory
    // cost.
    TORCH_WARN(
        "log1p_backward: received self with sparse layout, but backward requires materialization of a dense tensor with this shape");
    self_p1_conj = (self.to_dense() + 1).conj();
  } else {
    // self.to_dense() would be a no-op for strided self, but calling it here
    // breaks functorch tests.
    self_p1_conj = (self + 1).conj();
  }

  if (!has_sparse_layout(grad)) {
    return grad / self_p1_conj;
  }
  return grad * self_p1_conj.reciprocal_();
}

// Backward of sinc(x) = sin(pi x) / (pi x).
// d/dx sinc(x) = (pi x cos(pi x) - sin(pi x)) / (pi x^2). That quotient is 0/0
// at x == 0, where the true derivative is 0, so those entries are zeroed
// explicitly.
Tensor sinc_backward(const Tensor& grad, const Tensor& self) {
  const Tensor pi_x = self * M_PI;
  const Tensor pi_x_sq = self * self * M_PI;
  const Tensor derivative = (pi_x * pi_x.cos() - pi_x.sin()) / pi_x_sq;
  const Tensor scaled = grad * derivative.conj();
  return at::where(pi_x_sq == 0.0, at::zeros({}, grad.options()), scaled);
}

// Backward of constant_pad_nd. Padding by `pad` is undone by padding the
// gradient by the negated amounts, which crops the padded border (or restores
// a cropped region, for negative pads), filling with 0.
Tensor constant_pad_nd_backward(const Tensor& grad, c10::SymIntArrayRef pad) {
  std::vector<c10::SymInt> negated_pad;
  negated_pad.reserve(pad.size());
  for (const auto& amount : pad) {
    negated_pad.push_back(-amount);
  }
  return at::constant_pad_nd_symint(grad, negated_pad, 0);
}

// Double backward of dense embedding: gathers the rows of `grad` selected by
// `indices` and returns them shaped as indices.shape + [-1]. Rows gathered
// for `padding_idx` (when it is non-negative) are zeroed. Scaling by
// frequency was already applied in the first backward, so it is not repeated
// here.
Tensor embedding_dense_double_backward_symint(
    const Tensor& grad,
    const Tensor& indices,
    const c10::SymInt& padding_idx) {
  auto gg_weight = grad.index_select(0, indices.reshape(-1));
  if (padding_idx >= 0) {
    const auto padding_rows = (indices == padding_idx).reshape({-1, 1});
    gg_weight.masked_fill_(padding_rows, 0);
  }
  auto out_shape = indices.sizes().vec();
  out_shape.push_back(-1);
  return gg_weight.view(out_shape);
}

// Backward of advanced indexing: scatter-adds `grad` into `zeros_like_self` at
// `indices` (accumulate = true). Tensor-subclass-like inputs use the
// out-of-place index_put. Otherwise zeros_like_self, which is owned by this
// call, is written in place through the unsafe _index_put_impl_ fast path.
Tensor index_backward(
    Tensor zeros_like_self,
    const torch::List<std::optional<Tensor>>& indices,
    const Tensor& grad) {
  const bool involves_subclass =
      areAnyTensorSubclassLike({zeros_like_self, grad}) ||
      areAnyOptionalTensorSubclassLike(indices);
  if (involves_subclass) {
    return zeros_like_self.index_put(indices, grad, true);
  }
  return at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
}

// Backward of cuDNN CTC loss. The raw per-element gradient produced by the
// forward is scaled by the incoming gradient. grad_out is unsqueezed at dims
// 0 and 2 so that it broadcasts against raw_grad; the layout of raw_grad is
// assumed, not shown here. With `zero_infinity`, positions whose loss is 0
// (presumably the infinite losses zeroed by the forward) get a zero gradient.
Tensor _cudnn_ctc_loss_backward(
    const Tensor& grad_out,
    const Tensor& loss,
    const Tensor& raw_grad,
    bool zero_infinity) {
  Tensor scaled = raw_grad * grad_out.unsqueeze(0).unsqueeze(2);
  if (!zero_infinity) {
    return scaled;
  }
  const Tensor zero_loss = loss.unsqueeze(0).unsqueeze(2) == 0;
  return at::where(zero_loss, at::zeros({}, raw_grad.options()), scaled);
}

bool any_variable_defined(const variable_list& variables) {
  for (const auto& variable : variables) {
    if (variable.defined()) {
      return true;
    }
  }
  return false;
}

// Derivations for the householder_product.backward method.
//
// Given a sequence of vectors v_1, ..., v_k and a sequence of scalars tau_1,
// ..., tau_k, torch.linalg.householder_product computes the first n columns
// of the product Q = (I - tau_1 v_1 v_1^H) ... (I - tau_k v_k v_k^H). Let
//     H_i(sigma) := I - sigma v_i v_i^H, so
//     Q = (H_1(tau_1) ... H_k(tau_k))[:, :n];
//     H_i_minus := H_1(tau_1) ... H_{i - 1}(tau_{i - 1}), with H_1_minus := I;
//     H_i_plus := H_{i + 1}(tau_{i + 1}) ... H_k(tau_k), with H_k_plus := I.
//
// Forward AD:
// dQ = sum_{i = 1}^k H_i_minus (-dtau_i v_i v_i^H - tau_i dv_i v_i^H - tau_i
// v_i dv_i^H) H_i_plus.
//
// Backward AD:
// Tr(Q_grad^H dQ) = sum_{i = 1}^k Tr(H_i_plus Q_grad^H H_i_minus (-dtau_i v_i
// v_i^H - tau_i dv_i v_i^H - tau_i v_i dv_i^H)). Let K_i := H_i_plus Q_grad^H
// H_i_minus. The gradients are then v_i_grad = (-tau_i v_i^H K_i)^H - tau_i K_i
// v_i and tau_i_grad = Tr(-v_i^H K_i v_i).conj(). NOTE: the algorithm ignores
// the fact that only n columns of Q are observed, so there is no need to
// recompute Q to full completion.
//
// Note that K_{i + 1} = H_{i + 1}^{-1} K_i H_i, so v_i_grad and tau_i_grad can
// be computed one by one, provided K_i can be updated efficiently.
// Multiplying by H_i from the right can be done with matrix-vector products.
// The left update needs the inverse H_{i + 1}^{-1}, which may not exist.
// Under some assumptions, however, H_{i + 1}^{-1} exists and can be written
// as H_{i + 1}(sigma_{i + 1}) for some sigma_{i + 1}. The left update can then
// also be done with matrix-vector rather than matrix-matrix products.
//
// Let H(tau) := I - tau v v^H.
// H(tau) has eigenvalue 1 with multiplicity (m - 1), with eigenvectors
// orthogonal to v, and eigenvalue (1 - tau ||v||^2) with eigenvector
// v / ||v||. Hence H(tau) is invertible iff (1 - tau ||v||^2) != 0.
// In that case, define sigma := tau / (||v||^2 tau - 1).
// Expanding H(tau) H(sigma) gives I + (tau sigma ||v||^2 - tau - sigma) v v^H,
// and the coefficient of v v^H vanishes for this sigma.
// So H(tau) H(sigma) = H(sigma) H(tau) = I, and H(sigma) is the inverse of
// H(tau).
//
// WARNING: the algorithm below assumes that all H_i(tau_i) are invertible, so
// it expects (1 - tau_i ||v_i||^2) != 0 for all i.
// If some H_i(tau_i) is not invertible, householder_product is still
// differentiable. K_i cannot be computed efficiently in that case, however:
// evaluating each K_i would then require calls to ORGQR to compute H_i_plus.

// Applies a rank-1 update built from u = u_full, v = v_full and the (batched)
// scalar t to K. The result depends on `left` and `condition_with_I`:
//   left = true,  condition_with_I = true:   (I - t u v^H) K
//   left = true,  condition_with_I = false:  -t u v^H K
//   left = false, condition_with_I = true:   K (I - t u v^H)
//   left = false, condition_with_I = false:  -K t u v^H
// In-place mode (`modify_K_in_place = true`):
//   - K is updated in place and returned.
//   - The `condition_with_I = true` form is always used; the flag is ignored.
//   - The first `k` entries of both u_full and v_full are assumed to be zero,
//     so only rows (left) or columns (right) k..m-1 of K are read and written.
// Otherwise a new tensor is returned and K is not modified.
// Assumed shapes: u_full and v_full are (..., m, 1); t is (..., 1).
static Tensor apply_simple_transformation(
    const c10::SymInt& m,
    const c10::SymInt& k,
    const Tensor& u_full,
    const Tensor& v_full,
    const Tensor& t,
    Tensor& K,
    bool modify_K_in_place = true,
    bool condition_with_I = true,
    bool left = true) {
  // TODO: matrix-vector products in the code below are dispatched to
  // matrix-matrix products. We either need to extend matmul to support batched
  // matrix-vector products, or implement a batched variant of mv. We could
  // enable mv for inputs which are not batched, but it is not done to eliminate
  // the code duplication.
  const Tensor t_col = t.unsqueeze(-1);

  if (left) {
    if (modify_K_in_place) {
      // (I - t u v^H) K = K - (t u)(v^H K). v^H K only reads rows k.. of K
      // and only rows k.. of K change.
      const Tensor u_tail = u_full.narrow_symint(-2, k, m - k);
      const Tensor vH_K = v_full.narrow_symint(-2, k, m - k)
                              .mH()
                              .matmul(K.narrow_symint(-2, k, m - k));
      K.narrow_symint(-2, k, m - k).sub_((t_col * u_tail) * vH_K);
      return K;
    }
    const Tensor transformation = (t_col * u_full) * v_full.mH().matmul(K);
    return condition_with_I ? K - transformation : -transformation;
  }

  if (modify_K_in_place) {
    // K (I - t u v^H) = K - (K t u) v^H. K u only reads columns k.. of K and
    // only columns k.. of K change. The roles of u and v must match the
    // out-of-place branch below; they were previously swapped here, which
    // only went unnoticed when u_full and v_full are the same tensor.
    const Tensor v_tail = v_full.narrow_symint(-2, k, m - k);
    const Tensor K_t_u = K.narrow_symint(-1, k, m - k)
                             .matmul(t_col * u_full.narrow_symint(-2, k, m - k));
    K.narrow_symint(-1, k, m - k).sub_(K_t_u * v_tail.mH());
    return K;
  }
  const Tensor transformation = K.matmul(t_col * u_full) * v_full.mH();
  return condition_with_I ? K - transformation : -transformation;
}

// Backward of householder_product; ormqr_backward also reuses it.
//
// `input_` holds the Householder vectors in its columns. The forward only
// reads the strictly-lower-triangular part and treats the diagonal as 1s.
// `tau` holds k = tau.size(-1) scalars, `result` is the forward output and
// `grad` is its incoming gradient.
//
// Returns (input_grad, tau_grad):
//   - input_grad has the shape of `input_`, is strictly lower triangular, and
//     is zero in columns >= k;
//   - both are undefined tensors when `grad` is undefined or either input is
//     empty.
//
// The algorithm follows the K_i recursion described in the comment above
// `apply_simple_transformation`. K starts as result @ grad^H. For each
// reflector, (v_i_grad, tau_i_grad) are read off K, and then K is updated as
// K <- H_{i + 1}^{-1} K H_i.
std::tuple<Tensor, Tensor> householder_product_backward(
    const Tensor& grad,
    const Tensor& result,
    const Tensor& input_,
    const Tensor& tau,
    const bool flip_order) {
  // NOTE on `flip_order`: when flip_order is true,
  // the algorithm below reverses the processing direction from
  // range(k) to range(k - 1, -1, -1) in the main loop, and left/right
  // Householder projection applications get flipped.
  // The comments below about the algorithmic details assume flip_order = false.
  if (!grad.defined() || input_.sym_numel() == 0 || tau.sym_numel() == 0) {
    return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
  }
  auto m = input_.sym_size(-2);
  // guard_int is due to irange calls below
  auto k = tau.sym_size(-1).guard_int(__FILE__, __LINE__);

  // forward operates only over the lower triangular part with the assumption
  // that the diagonal of input is filled with 1s.
  auto input = input_.tril(-1);
  input.diagonal(0, -2, -1).fill_(1.0);

  // compute sigma such that
  // H(sigma_i) == H(tau_i)^{-1}.
  // If the input to householder_product comes from GEQRF,
  // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be
  // invertible. This follows from the documentation
  // https://www.netlib.org/lapack/lug/node128.html, and tau always satisfying
  // the condition |tau|^2 ||v||^2 == 2 * Re(tau).
  auto input_first_k_cols = input.narrow(-1, 0, k);
  auto input_first_k_cols_norm_squared =
      (input_first_k_cols * input_first_k_cols.conj()).sum(-2);
  auto sigma = tau / (tau * input_first_k_cols_norm_squared - 1.0);

  // Initial value of K. The loops below repeatedly read gradients from it and
  // update it with Householder reflectors.
  auto K = result.matmul(grad.mH());

  // The algorithm updates K by multiplying it from the left/right with
  // Householder reflectors. If only single backward is run, we modify K
  // in-place and exploit triangularity of the input. With higher order
  // derivatives we cannot rewrite the storage of K, hence we use much less
  // efficient out-of-place methods.
  //
  // if only first-order derivative is expected, we can modify K in-place for
  // better performance
  bool modify_K_in_place = !at::GradMode::is_enabled();

  // This method exploits that at k-th iteration vector v_k has only elements
  // v_k[k:] which are non-zero.
  // NOTE: the lambda parameter `k` is the current column index; it shadows the
  // outer `k` (the number of reflectors).
  // Returns (v_grad, tau_grad) for that column, following
  // tau_i_grad = Tr(-v_i^H K_i v_i).conj() from the derivation above.
  auto update_grad = [&m](
                         int64_t k,
                         const Tensor& v_full,
                         const Tensor& t,
                         const Tensor& K) -> std::tuple<Tensor, Tensor> {
    // v_full is a vector of dimension (..., m, 1), t is a scalar of dimension
    // (..., 1)
    auto v = v_full.narrow_symint(-2, k, m - k);
    auto vHK = v.mH().matmul(K.narrow_symint(-2, k, m - k));
    auto Kv = K.narrow_symint(-1, k, m - k).matmul(v);
    auto t_unsqueezed = t.unsqueeze(-1);
    auto v_grad = (-t_unsqueezed * vHK).conj().squeeze(-2) -
        (t_unsqueezed * Kv).squeeze(-1);
    auto tau_grad = -(vHK.narrow_symint(-1, k, m - k).matmul(v)).conj();
    return std::make_tuple(v_grad.unsqueeze(-1), tau_grad.squeeze(-1));
  };

  // K <- H(t) K (left = true) or K <- K H(t) (left = false) for the reflector
  // stored in column k, in place when modify_K_in_place is set.
  auto apply_householder_reflector = [m, modify_K_in_place](
                                         int64_t k,
                                         const Tensor& v_full,
                                         const Tensor& t,
                                         Tensor& K,
                                         bool left = true) -> Tensor {
    return apply_simple_transformation(
        m,
        k,
        v_full,
        v_full,
        t,
        K,
        modify_K_in_place,
        /*condition_with_I=*/true,
        left);
  };

  // Map the loop counter to the column index: identity for the forward order,
  // reversed (k - 1, ..., 0) when flip_order is set.
  const auto flip_i = [flip_order, k](int64_t i) -> int64_t {
    return !flip_order ? i : k - i - 1;
  };
  // Index of the column processed right after column i in the chosen order.
  const auto next_i = [flip_order](int64_t i) -> int64_t {
    return !flip_order ? ++i : --i;
  };
  const auto apply_left = !flip_order;

  // K <- H_0^{-1} @ K
  const auto zero_idx = flip_i(0);
  K = apply_householder_reflector(
      zero_idx,
      input.narrow(-1, zero_idx, 1),
      sigma.narrow(-1, zero_idx, 1),
      K,
      /*left=*/apply_left);

  Tensor input_grad, tau_grad;
  // For Composite Compliance, we can't copy a Subclass into a Regular Tensor,
  // so we use out-of-place ops with equivalent output.
  // NOTE: We can't use `new_zeros` directly as `input`, 'tau' or `grad` can
  // be Tensor Subclass and we don't want to make assumption about which
  // one to choose for creating output buffer.
  // eg. if both are BatchedTensor at different level.
  if (areAnyTensorSubclassLike({input, tau, K})) {
    // k + 1 if input_grads hold a matrix of zeros for inactive parts of input.
    auto input_grads = std::vector<Tensor>(k < input.sym_size(-1) ? k + 1 : k);
    auto tau_grads = std::vector<Tensor>(k);

    for (const auto i_idx : c10::irange(k)) {
      auto i = flip_i(i_idx);
      // NOTE: narrow will unsqueeze(-1)
      auto v_i = input.narrow(-1, i, 1);
      auto t_i = tau.narrow(-1, i, 1);

      // Stored at the column index i (not the loop counter), so the final
      // concatenation is in column order regardless of flip_order.
      std::tie(input_grads[i], tau_grads[i]) = update_grad(i, v_i, t_i, K);

      // K <- H_{i + 1}^{-1} @ K @ H_i
      // (skipped after the last processed column)
      if (i != flip_i(k - 1)) {
        auto i_next = next_i(i);
        auto v_i_next = input.narrow(-1, i_next, 1);
        auto s_i_next = sigma.narrow(-1, i_next, 1);
        K = apply_householder_reflector(
            i_next, v_i_next, s_i_next, K, /*left=*/apply_left);
        K = apply_householder_reflector(i, v_i, t_i, K, /*left=*/!apply_left);
      }
    }

    // Only first k columns are active in forward.
    // zero gradients for the inactive input.
    if (k < input.sym_size(-1)) {
      auto zero_grad_shape =
          at::SymDimVector(input_.sym_sizes().slice(0, input_.dim() - 1));
      zero_grad_shape.push_back(input.sym_size(-1) - k);
      auto zero_grad = at::zeros_symint(zero_grad_shape, input_.options());
      input_grads[k] = zero_grad;
    }

    input_grad = at::cat(input_grads, -1);
    tau_grad = at::cat(tau_grads, -1);
  } else {
    // Regular tensors: preallocate zero-filled outputs and write each column
    // in place; columns >= k keep their zeros.
    input_grad = at::zeros_like(input_);
    tau_grad = at::zeros_like(tau);
    for (const auto i_idx : c10::irange(k)) {
      auto i = flip_i(i_idx);
      // NOTE: narrow will unsqueeze(-1)
      auto v_i = input.narrow(-1, i, 1);
      auto t_i = tau.narrow(-1, i, 1);

      auto [v_i_grad, tau_i_grad] = update_grad(i, v_i, t_i, K);
      input_grad.select(-1, i).copy_(v_i_grad.squeeze(-1));
      tau_grad.select(-1, i).copy_(tau_i_grad.squeeze(-1));

      // K <- H_{i + 1}^{-1} @ K @ H_i
      // (skipped after the last processed column)
      if (i != flip_i(k - 1)) {
        auto i_next = next_i(i);
        auto v_i_next = input.narrow(-1, i_next, 1);
        auto s_i_next = sigma.narrow(-1, i_next, 1);
        K = apply_householder_reflector(
            i_next, v_i_next, s_i_next, K, /*left=*/apply_left);
        K = apply_householder_reflector(i, v_i, t_i, K, /*left=*/!apply_left);
      }
    }
  }

  // forward operates only over the lower-triangular part of the input
  // excluding the main diagonal, hence the gradient is also lower-triangular.
  input_grad.tril_(-1);

  return std::make_tuple(input_grad, tau_grad);
}

// Forward-mode derivative (JVP) of householder_product.
//
// We refer to the derivations described above the method
// `apply_simple_transformation`. With prod = H_1 H_2 ... H_k and
// H_i = I - tau_i v_i v_i^H, the product rule gives
//   dprod = sum_i H_1 ... H_{i-1} dH_i H_{i+1} ... H_k,
//   dH_i  = -(dtau_i v_i v_i^H + tau_i dv_i v_i^H + tau_i v_i dv_i^H).
// The loop maintains:
//   - H_minus = H_1 ... H_{i-1}, starting from a batched identity;
//   - H_plus, starting as a copy of `prod`, from which H_i is peeled off on
//     the left with H_i^{-1} = H(sigma_i).
Tensor householder_product_jvp(
    const Tensor& dV_,
    const Tensor& dtau,
    const Tensor& prod,
    const Tensor& V_,
    const Tensor& tau) {
  auto m = V_.sym_size(-2);
  // guard_int is due to the irange loop below. This mirrors
  // householder_product_backward; plain `size()` / `sizes()` may not be used
  // on tensors with symbolic sizes.
  auto k = tau.sym_size(-1).guard_int(__FILE__, __LINE__);

  // forward operates only over the lower triangular part with the assumption
  // that the diagonal of input is filled with 1s.
  auto V = V_.tril(-1);
  V.diagonal(0, -2, -1).fill_(1.0);
  auto dV = dV_.tril(-1);

  // compute sigma such that
  // H(sigma_i) == H(tau_i)^{-1}.
  // If the input to householder_product comes from GEQRF,
  // we will never encounter ||v_i||^2 tau_i == 1, so H(tau_i) will always be
  // invertible. This follows from the documentation
  // https://www.netlib.org/lapack/lug/node128.html, and tau always satisfying
  // the condition |tau|^2 ||v||^2 == 2 * Re(tau).
  auto V_first_k_cols = V.narrow(-1, 0, k);
  auto V_first_k_cols_norm_squared =
      (V_first_k_cols * V_first_k_cols.conj()).sum(-2);
  auto sigma = tau / (tau * V_first_k_cols_norm_squared - 1.0);

  // Computes (I - t v v^H) K (left) or K (I - t v v^H) (right), out of place.
  auto apply_householder_reflector = [m](const Tensor& v_full,
                                         const Tensor& t,
                                         Tensor& K,
                                         bool left = true) -> Tensor {
    return apply_simple_transformation(
        // Setting `modify_K_in_place = true` causes CUDA memory leaks in the
        // OpInfo tests of forward AD. We therefore always work out of place,
        // where `k` is ignored, and pass zero for it.
        m,
        /*k=*/0,
        v_full,
        v_full,
        t,
        K,
        /*modify_K_in_place=*/false,
        /*condition_with_I=*/true,
        left);
  };

  // computes (-t u v^H) K
  auto apply_simple_product = [m](const Tensor& u_full,
                                  const Tensor& v_full,
                                  const Tensor& t,
                                  Tensor& K) -> Tensor {
    return apply_simple_transformation(
        // since `modify_K_in_place = false`, `k` is ignored and we may pass
        // an arbitrary value
        m,
        /*k=*/0,
        u_full,
        v_full,
        t,
        K,
        /*modify_K_in_place=*/false,
        /*condition_with_I=*/false,
        /*left=*/true);
  };

  auto H_plus = prod.detach().clone();
  // Batched identity of shape (..., m, m) built from V's batch + row sizes.
  auto batch_vector_shape = V.sym_sizes().slice(0, V.dim() - 1);
  auto H_minus = at::diag_embed(
      at::ones({1}, V.options()).expand_symint(batch_vector_shape));

  auto dprod = at::zeros_like(prod);
  for (const auto i : c10::irange(k)) {
    auto v_i = V.narrow(-1, i, 1);
    auto dv_i = dV.narrow(-1, i, 1);
    auto tau_i = tau.narrow(-1, i, 1);
    auto dtau_i = dtau.narrow(-1, i, 1);
    auto sigma_i = sigma.narrow(-1, i, 1);

    H_plus = apply_householder_reflector(v_i, sigma_i, H_plus, /*left=*/true);

    // `H_minus_dH_i_H_plus` = H_1 * ... * H_{i-1} dH_i * H_{i+1} * ...
    auto H_minus_dH_i_H_plus = H_minus.matmul(
        apply_simple_product(v_i, v_i, dtau_i, H_plus) +
        apply_simple_product(dv_i, v_i, tau_i, H_plus) +
        apply_simple_product(v_i, dv_i, tau_i, H_plus));
    // For Composite Compliance, if `intermediate` is a Tensor-Subclass,
    // we use out-of-place variant of add.
    if (at::isTensorSubclassLike(H_minus_dH_i_H_plus)) {
      dprod = dprod.add(H_minus_dH_i_H_plus);
    } else {
      dprod.add_(H_minus_dH_i_H_plus);
    }

    H_minus = apply_householder_reflector(v_i, tau_i, H_minus, /*left=*/false);
  }

  return dprod;
}

std::tuple<Tensor, Tensor, Tensor> ormqr_backward(
    const Tensor& grad,
    const Tensor& result,
    const Tensor& self,
    const Tensor& tau,
    const Tensor& other,
    bool left,
    bool transpose,
    std::array<bool, 3> grad_output_mask) {
  Tensor self_grad, tau_grad, other_grad;

  if (!grad.defined()) {
    return std::make_tuple(self_grad, tau_grad, other_grad);
  }

  const auto self_requires_grad = grad_output_mask[0];
  const auto tau_requires_grad = grad_output_mask[1];
  const auto other_requires_grad = grad_output_mask[2];

  if (other_requires_grad) {
    other_grad = at::ormqr(self, tau, grad, left, !transpose);
  }
  if (self_requires_grad || tau_requires_grad) {
    if (left ^ transpose) {
      // Assume left = true and transpose = false. The case with
      // left = false and transpose = true is very much similar with just
      // transposed arguments passed into householder_product_backward.
      // Ormqr computes B = H_1 * ... * H_k * A.
      // The sensivity wrt H_i is given by (see notes in
      // householder_product_backward) Tr(H_i_plus B B_grad^H H_i_minus dH_i),
      // so, since householder_product_backward respects `for i in range(k)`, we
      // could reuse householder_product_backward with
      // householder_product_backward.grad = grad and
      // householder_product_backward.result = result.
      const auto hpb_grad = !transpose ? grad : grad.mH();
      const auto hpb_result = !transpose ? result : result.mH();
      std::tie(self_grad, tau_grad) =
          householder_product_backward(hpb_grad, hpb_result, self, tau);
    } else {
      // Assuming left = false and transpose = false. The case with
      // left = true and transpose = true is very much similar with just
      // transposed arguments passed into householder_product_backward.
      // In this case Ormqr computes B = H_1 * ... * H_k * A and the sensitivity
      // wrt H_i becomes Tr(H_i_plus B_grad^H B H_i_minus dH_k).
      // We could see that the role of `grad` and `result` in
      // householder_product_backward gets "swapped" and "transposed" and that
      // in order to compute H_k_grad efficiently we would need to compute grads
      // in reversed order (`for i in range(k - 1, -1, -1)`). Hence we reuse
      // householder_product_backward with householder_product_backward.grad =
      // result.mH, householder_product_backward.result = grad.mH,
      // householder_product_backward.flip_order = true.
      const auto hpb_grad = !transpose ? result.mH() : result;
      const auto hpb_result = !transpose ? grad.mH() : grad;
      std::tie(self_grad, tau_grad) = householder_product_backward(
          hpb_grad, hpb_result, self, tau, /*flip_order=*/true);
    }
  }

  return std::make_tuple(self_grad, tau_grad, other_grad);
}

std::tuple<Tensor, Tensor> polar_backward(
    const Tensor& grad,
    const Tensor& result) {
  Tensor grad_abs, grad_angle;
  if (grad.defined()) {
    auto grad_conj = grad.conj();
    grad_abs = at::real(grad_conj * at::sgn(result));
    auto result_mul_1_j = result * Scalar(c10::complex<double>{0.0, 1.0});
    grad_angle = at::real(grad_conj * result_mul_1_j);
  }
  return std::make_tuple(grad_abs, grad_angle);
}

Tensor i1_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& result) {
  return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "i1_backward", [&]() {
    // d/dx I1(x) = I0(x) - I1(x) / x. At x = 0 the formula is 0/0 and would
    // give NaN, while the correct gradient there is 0.5, so inputs with
    // |x| <= eps are handled separately.
    const auto eps = std::numeric_limits<scalar_t>::epsilon();
    const auto is_regular = self.abs() > eps;

    // Near-zero inputs are replaced *before* evaluating the formula. `where`
    // computes gradients even for the branch that does not reach the output,
    // so a NaN there would leak.
    // Look at https://github.com/pytorch/pytorch/issues/52248
    // Update if and when this is fixed.
    const auto x =
        at::where(is_regular, self, at::full({}, eps, self.options()));
    const auto derivative = x.i0() - result * x.reciprocal();
    const auto value_at_zero = at::full({}, 0.5, self.options());
    return grad * at::where(is_regular, derivative, value_at_zero);
  });
}

Tensor i1e_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& result) {
  return AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "i1e_backward", [&]() {
    // With i1e(x) = exp(-|x|) I1(x):
    //   d/dx i1e(x) = i0e(x) - i1e(x) * (sgn(x) + 1 / x).
    // The 1 / x term would give NaN at x = 0, where the correct gradient is
    // 0.5, so inputs with |x| <= eps are handled separately.
    const auto eps = std::numeric_limits<scalar_t>::epsilon();
    const auto is_regular = self.abs() > eps;

    // Near-zero inputs are replaced *before* evaluating the formula. `where`
    // computes gradients even for the branch that does not reach the output,
    // so a NaN there would leak.
    // Look at https://github.com/pytorch/pytorch/issues/52248
    // Update if and when this is fixed.
    const auto x =
        at::where(is_regular, self, at::full({}, eps, self.options()));
    const auto derivative =
        at::special_i0e(x) - result * (x.sgn() + x.reciprocal());
    const auto value_at_zero = at::full({}, 0.5, self.options());
    return grad * at::where(is_regular, derivative, value_at_zero);
  });
}

// lu_solve is a map (LU, P, B) -> (PLU)^{-1} B,
// where LU = L + U - I and P is a permutation matrix, and is fixed.
//
// Let 1 = ones_like(LU),
// 1_U = 1.triu(),
// 1_L = 1.tril(-1)
// * := the Hadamard (element-wise) product
//
// Forward AD:
//
// Let X := U^{-1} L^{-1} P^T B be the output of the function.
// Also, the LU input of the function could be represented as
// LU = (L - I) + U.
//
// Differentiating LU = L + U - I produces:
// dLU = dL + dU.
// Noting that dL and dU are lower- and upper-triangular, respectively,
// and that the diagonal of L is never explicitly exposed, so
// diag(dL) = 0, it follows
// dL = dLU * 1_L,
// dU = dLU * 1_U.
//
// Differentiating X = U^{-1} L^{-1} P^T B produces:
// dX = dU^{-1} L^{-1} P^T B + U^{-1} dL^{-1} P^T B + U^{-1} L^{-1} P^T dB
// Note that for any invertible matrix A we have A A^{-1} = I, hence
// dA A^{-1} + A dA^{-1} = 0 => dA^{-1} = -A^{-1} dA A^{-1}.
// Inserting it back into the definition of dX gives:
// dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B
//      + U^{-1} L^{-1} P^T dB,
// which, substituting X, reads
// dX = -U^{-1} dU X - U^{-1} L^{-1} dL U X + U^{-1} L^{-1} P^T dB.
//
// Backward AD:
//
// Using the definition of dL, dU from above:
// Tr(L_grad^H dL) + Tr(U_grad^H dU)
//   = Tr(L_grad^H (dLU * 1_L)) + Tr(U_grad^H (dLU * 1_U))
//   = [using Tr(A (B * C)) = Tr((A * B^T) C)]
//   = Tr((L_grad^H * 1_L^T) dLU) + Tr((U_grad^H * 1_U^T) dLU),
// hence
// LU_grad = L_grad * 1_L + U_grad * 1_U (!!!)
//
// Then, transposing the formula for dX above we get:
// B_grad = P L^{-H} U^{-H} X_grad
//        = lu_solve(X_grad, LU_data, LU_pivots, /*adjoint=*/true),
// U_grad = -U^{-H} X_grad X^H,
// L_grad = L^{-H} U_grad U^H.
// After inserting U_grad and L_grad into (!!!) we get the value for LU_grad.

Tensor linalg_lu_solve_LU(
    const Tensor& gX,
    const Tensor& LU,
    const Tensor& pivots,
    const Tensor& X,
    const bool left,
    const bool adjoint) {
  // Gradient of linalg_lu_solve with respect to the packed LU factorization,
  // i.e. the adjoint of the dLU part of linalg_lu_solve_jvp.
  //
  // From linalg_lu_solve_jvp we have that:
  //   left = True,  adjoint = True:  A^H X = B
  //   left = True,  adjoint = False: A X = B
  //   left = False, adjoint = True:  A X^H = B^H
  //   left = False, adjoint = False: A^H X^H = B^H
  // Let op_1(A) = A^H or op_1(A) = A according to the list above, and likewise
  // for op_2(X) and op_3(B). With S = lu_solve(LU, pivots, dB, left, adjoint)
  // the JVP formula reads
  //   if left != adjoint:
  //     dX = op_2(-U^{-1}(dU + L^{-1} dL U) op_2(X)) + S
  //   else:
  //     dX = op_2(op_1(-op_3(X)^H P (L dU U^{-1} + dL) L^{-1} P^T)) + S
  // Taking the adjoint of this operation, with an auxiliary variable gR:
  //   if left != adjoint:
  //     gR = U^{-H} op_2(-gX) op_2(X)^H
  //     gU = gR.triu()
  //     gL = (L^{-H} gR U^H).tril(-1)
  //   else:
  //     gR = -P^T op_3(X) op_1(op_2(gX)) P L^{-H}
  //     gL = gR.tril(-1)
  //     gU = (L^H gR U^{-H}).triu()
  //   gLU = gL + gU
  at::NoTF32Guard disable_tf32;
  // The permutation matrix is only used when left == adjoint.
  const bool needs_permutation = left == adjoint;
  auto [P, L, U] = at::lu_unpack(
      LU, pivots, /*unpack_data=*/true, /*unpack_pivots=*/needs_permutation);
  // TODO Optimise the order of the operations to avoid operating on large
  // tensors unnecessarily
  //      The logic should be: if n < k == left then multiply the gX and X first
  //      (as it's done now) Otherwise multiply them last

  if (needs_permutation) {
    // gR = -P^T op_3(X) op_1(op_2(gX)) P
    auto gR = -P.mT()
                   .matmul(left ? X : X.mH())
                   .matmul(left ? gX.mH() : gX)
                   .matmul(P);
    // gR = gR L^{-H}
    gR = at::linalg_solve_triangular(
        L.mH(), gR, /*upper*/ true, /*left*/ false, /*unitriangular*/ true);
    // gU = (L^H gR U^{-H}).triu()
    const auto gU = at::linalg_solve_triangular(
                        U.mH(),
                        L.mH().matmul(gR),
                        /*upper*/ false,
                        /*left*/ false)
                        .triu();
    return gR.tril(-1) + gU;
  }

  // gR = U^{-H} op_2(-gX) op_2(X)^H
  const auto gX_op = left ? gX : gX.mH();
  const auto X_op_adj = left ? X.mH() : X;
  const auto gR = at::linalg_solve_triangular(
      U.mH(), -gX_op.matmul(X_op_adj), /*upper*/ false);
  // gL = (L^{-H} gR U^H).tril(-1)
  const auto gL = at::linalg_solve_triangular(
                      L.mH(),
                      gR.matmul(U.mH()),
                      /*upper*/ true,
                      /*left*/ true,
                      /*unitriangular*/ true)
                      .tril(-1);
  return gL + gR.triu();
}

Tensor linalg_lu_solve_jvp(
    const Tensor& X,
    const Tensor& LU,
    const Tensor& pivots,
    const Tensor& dLU,
    const Tensor& dB,
    const bool left,
    const bool adjoint) {
  // Forward-mode derivative of linalg_lu_solve.
  //
  // We write the derivation in terms of some adjoint operations; otherwise we
  // would need 4 different proofs with 4 different implementations, which
  // would be painful to derive and maintain. Below we just use that X -> X^H
  // is linear, so it commutes with the derivative. The derivation follows by
  // differentiating op_1(PLU) op_2(X) = op_3(B), where
  //   left = True,  adjoint = True:  A^H X = B
  //   left = True,  adjoint = False: A X = B
  //   left = False, adjoint = True:  A X^H = B^H
  //   left = False, adjoint = False: A^H X^H = B^H
  // and op_1(A) = A^H or op_1(A) = A according to the list above (likewise
  // for op_2(X) and op_3(B)). With S = lu_solve(LU, pivots, dB, left, adjoint)
  // the JVP formula reads
  //   dX = op_2(op_1(-U^{-1}(dU U^{-1} + L^{-1} dL) L^{-1} P^T) op_3(B)) + S
  at::NoTF32Guard disable_tf32;
  auto S = at::linalg_lu_solve(LU, pivots, dB, left, adjoint);

  if (left == adjoint) {
    // Here op_1(A) = A^H, so
    //   dX = op_2(op_1(-op_3(B)^H U^{-1}(dU U^{-1} + L^{-1} dL) L^{-1} P^T)) + S.
    // Whenever adjoint == left we have op_3(B)^H A^{-1} = op_3(X)^H, so in
    // terms of X:
    //   dX = op_2(op_1(-op_3(X)^H P (L dU U^{-1} + dL) L^{-1} P^T)) + S
    auto [P, L, U] = at::lu_unpack(LU, pivots);
    // V = op_3(X)^H
    const auto V = left ? X.mH() : X;
    // Inner parentheses: L dU U^{-1} + dL
    auto inner = at::linalg_solve_triangular(
                     U, L.matmul(dLU.triu()), /*upper*/ true, /*left*/ false) +
        dLU.tril(-1);
    // R = -op_3(X)^H P (inner) L^{-1} P^T, so that dX = op_2(R^H) + S
    auto R = at::linalg_solve_triangular(
                 L,
                 -V.matmul(P).matmul(inner),
                 /*upper*/ false,
                 /*left*/ false,
                 /*unitriangular*/ true)
                 .matmul(P.mT());
    return (left ? R.mH() : std::move(R)) + S;
  }

  // Here op_1(A) = A, and A^{-1} op_3(B) can be replaced by op_2(X):
  //   dX = op_2(-U^{-1}(dU + L^{-1} dL U) op_2(X)) + S
  // L^{-1} dL, using the unit-lower-triangular part stored in LU.
  auto Linv_dL = at::linalg_solve_triangular(
      LU,
      dLU.tril(-1),
      /*upper*/ false,
      /*left*/ true,
      /*unitriangular*/ true);
  auto U = LU.triu();
  // R = -U^{-1}(dU + L^{-1} dL U)
  auto R = -at::linalg_solve_triangular(
      U, dLU.triu() + Linv_dL.matmul(U), /*upper*/ true);
  // dX = op_2(R op_2(X)) + S
  return (left ? R.matmul(X) : X.matmul(R.mH())) + S;
}

Tensor linalg_solve_jvp(
    const Tensor& dA,
    const Tensor& dB,
    const Tensor& X,
    const Tensor& LU,
    const Tensor& pivots,
    const bool left,
    const bool use_A_T) {
  at::NoTF32Guard disable_tf32;
  // left = True:  AX = B  =>  dX = A^{-1}(dB - dA X)
  // left = False: XA = B  =>  dX = (dB - X dA) A^{-1}
  // The solve reuses the forward's LU factorization; `use_A_T` is forwarded
  // as the `adjoint` flag of linalg_lu_solve.
  //
  // [NumPy compat] The rhs may be a vector. It is promoted to a one-column
  // matrix via `unsqueeze(-1)` for the computation and squeezed back at the
  // end.
  const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);

  // This case is disallowed in the primal operation as A.shape = (*, 1, 1)
  TORCH_INTERNAL_ASSERT(left || !vector_case);

  const auto X_mat = vector_case ? X.unsqueeze(-1) : X;
  const auto dB_mat = vector_case ? dB.unsqueeze(-1) : dB;
  const auto dA_term = left ? dA.matmul(X_mat) : X_mat.matmul(dA);
  auto dX_mat = at::linalg_lu_solve(
      LU, pivots, dB_mat - dA_term, left, /*adjoint*/ use_A_T);
  return vector_case ? dX_mat.squeeze(-1) : dX_mat;
}

std::tuple<Tensor, Tensor> linalg_solve_backward(
    const Tensor& gX,
    const Tensor& X,
    const Tensor& A,
    const Tensor& LU,
    const Tensor& pivots,
    const bool left,
    const bool B_requires_grad) {
  // for X = A^{-1}B
  // gB = A^{-H}gX
  // gA = -gB X^H
  at::NoTF32Guard disable_tf32;
  const auto A_requires_grad = A.requires_grad();
  if (!gX.defined() || (!A_requires_grad && !B_requires_grad)) {
    return {};
  }

  // [NumPy compat] Case where the rhs is a vector.
  // We denote with an underscore vectors that have been converted to matrices
  // by `unsqueeze(-1)`
  const bool vector_case = at::native::linalg_solve_is_vector_rhs(LU, X);
  const auto vector_to_matrix = [vector_case](const Tensor& X) {
    return vector_case ? X.unsqueeze(-1) : X;
  };
  const auto matrix_to_vector = [vector_case](const Tensor& X) {
    return vector_case ? X.squeeze(-1) : X;
  };

  // If the user is going to compute higher order gradients, then we need to
  // recompute the LU and the pivots
  Tensor gB_;
  if (at::GradMode::is_enabled()) {
    gB_ = at::linalg_solve(A.mH(), vector_to_matrix(gX), left);
  } else {
    const auto use_A_T = A.is_contiguous() && !A.is_complex();
    gB_ = at::linalg_lu_solve(
        LU, pivots, vector_to_matrix(gX), left, /*adjoint*/ !use_A_T);
  }

  Tensor gA_;
  if (A_requires_grad) {
    auto X_ = vector_to_matrix(X);
    gA_ = left ? -gB_.matmul(X_.mH()) : -X_.mH().matmul(gB_);
  }
  return std::make_tuple(
      A_requires_grad ? std::move(gA_) : Tensor{},
      B_requires_grad ? matrix_to_vector(gB_) : Tensor{});
}

Tensor solve_jvp(
    const Tensor& X,
    const Tensor& A,
    const Tensor& dA,
    const Tensor& dB) {
  // JVP of X = A^{-1} B. generic_solve_jvp supplies the `dA_contrib` term
  // (presumably dA @ X; see its definition). The solver below maps the
  // perturbed right-hand side dB - dA_contrib through A^{-1}.
  const auto solve_perturbed_rhs = [](const Tensor& A_,
                                      const Tensor& dB_,
                                      const Tensor& dA_contrib) {
    return at::linalg_solve(A_, dB_ - dA_contrib);
  };
  return generic_solve_jvp(solve_perturbed_rhs, X, A, dA, dB);
}

Tensor lu_unpack_backward(
    const Tensor& L_grad,
    const Tensor& U_grad,
    const c10::SymInt& m,
    const c10::SymInt& n) {
  // Assembles the gradient of a packed (m x n) LU matrix from the gradients of
  // its unpacked factors, with k = min(m, n):
  //   - L has m rows, and only its strictly-lower part is taken;
  //   - U has n columns, and only its upper part is taken.
  // Returns an undefined tensor when neither gradient is defined.
  const bool has_L_grad = L_grad.defined();
  const bool has_U_grad = U_grad.defined();
  if (!has_L_grad && !has_U_grad) {
    return {};
  }
  const auto k = std::min(m, n);

  if (!has_U_grad) {
    if (m >= n) {
      return L_grad.tril(-1);
    }
    // Wide matrix: the trailing n - m columns receive zero gradient.
    auto pad_shape = L_grad.sym_sizes().vec();
    pad_shape.back() = n - m;
    return at::cat(
        {L_grad.tril(-1), at::zeros_symint(pad_shape, L_grad.options())},
        /*dim=*/-1);
  }

  if (!has_L_grad) {
    if (n >= m) {
      return U_grad.triu();
    }
    // Tall matrix: the trailing m - n rows receive zero gradient.
    auto pad_shape = U_grad.sym_sizes().vec();
    pad_shape[pad_shape.size() - 2] = m - n;
    return at::cat(
        {U_grad.triu(), at::zeros_symint(pad_shape, U_grad.options())},
        /*dim=*/-2);
  }

  if (m == n) {
    return L_grad.tril(-1) + U_grad.triu();
  }
  if (m > n) {
    // Tall (k = n): the top k x k block mixes both factors, and the remaining
    // rows come from L only.
    auto head = L_grad.narrow_symint(-2, 0, k).tril(-1) + U_grad.triu();
    auto tail = L_grad.narrow_symint(-2, k, m - k);
    return at::cat({std::move(head), std::move(tail)}, /*dim=*/-2);
  }
  // Wide (k = m): the left k x k block mixes both factors, and the remaining
  // columns come from U only.
  auto head = L_grad.tril(-1) + U_grad.narrow_symint(-1, 0, k).triu();
  auto tail = U_grad.narrow_symint(-1, k, n - k);
  return at::cat({std::move(head), std::move(tail)}, /*dim=*/-1);
}

Tensor cat_jvp(const at::ITensorListRef& tensors, int64_t dim) {
  Tensor out_fw_grad;

  auto materialized = tensors.materialize();
  auto any_defined = false;
  for (const Tensor& t : materialized) {
    any_defined |= isFwGradDefined(t);
  }

  if (any_defined) {
    std::vector<Tensor> fw_grads;

    for (const Tensor& t : materialized) {
      fw_grads.push_back(
          isFwGradDefined(t)
              ? t._fw_grad(/*level*/ 0)
              : at::_efficientzerotensor(t.sizes(), t.options()));
    }

    out_fw_grad = at::cat(fw_grads, dim);
  }

  return out_fw_grad;
}

Tensor block_diag_jvp(at::TensorList tensors) {
  Tensor out_fw_grad;

  auto any_defined = false;
  for (const auto& t : tensors) {
    any_defined |= isFwGradDefined(t);
  }

  if (any_defined) {
    std::vector<Tensor> fw_grads;
    fw_grads.reserve(tensors.size());

    for (const auto& t : tensors) {
      fw_grads.push_back(
          isFwGradDefined(t)
              ? t._fw_grad(/*level*/ 0)
              : at::_efficientzerotensor(t.sizes(), t.options()));
    }

    out_fw_grad = at::block_diag(fw_grads);
  }

  return out_fw_grad;
}

Tensor stack_jvp(at::TensorList tensors, int64_t dim) {
  // Basically copy of cat_jvp above
  // TODO: consolidate with the logic of cat_jvp
  Tensor out_fw_grad;

  auto any_defined = false;
  for (const auto& t : tensors) {
    any_defined |= isFwGradDefined(t);
  }

  if (any_defined) {
    std::vector<Tensor> fw_grads;

    for (auto& t : tensors) {
      fw_grads.push_back(
          isFwGradDefined(t)
              ? t._fw_grad(/*level*/ 0)
              : at::_efficientzerotensor(t.sizes(), t.options()));
    }
    out_fw_grad = at::stack(fw_grads, dim);
  }
  return out_fw_grad;
}

Tensor cumprod_jvp(
    const Tensor& self_t,
    const Tensor& self_p,
    const Tensor& result,
    int dim) {
  // Generic formula, valid when no 0 is involved:
  //   d cumprod(x)_j = cumprod(x)_j * sum_{l <= j} dx_l / x_l
  Tensor gradient = (self_t / self_p).cumsum(dim) * result;

  // Note that we have to use at::where / masked_fill below as we are removing
  // nans produced by the division by zero above.
  if (self_p.dim() == 0) {
    gradient.masked_fill_(self_p.eq(0), self_t);
    return gradient;
  }

  // Worked example: for input (a, 0, b, 0, c) with tangents
  // (t0, t1, t2, t3, t4), cumprod yields (a, 0, 0, 0, 0), and the tangent we
  // want is (t0, a*t1, a*b*t1, 0, 0).

  // zeros of the input:                        (0, 1, 0, 1, 0)
  const auto is_zero = self_p.eq(0);
  // first zero along `dim`:                    (0, 1, 0, 0, 0)
  const auto is_first_zero = is_zero.logical_and(is_zero.cumsum(dim).eq(1));
  // Swap the first zero for its tangent and take the cumprod; this yields the
  // correct values from the first zero on:
  //   cumprod((a, t1, b, 0, c)) = (X, a*t1, a*b*t1, 0, 0)
  const auto past_zero_grad =
      at::where(is_first_zero, self_t, self_p).cumprod(dim);
  // everything from the first zero on:         (0, 1, 1, 1, 1)
  const auto from_first_zero = is_first_zero.cumsum(dim);

  return at::where(
      from_first_zero.to(ScalarType::Bool), past_zero_grad, gradient);
}

// Helper for the {batch,layer,group}_norm jvps below.
// Tangent of invstd = 1 / std(input, dims), given the primal statistics
// mean_p and invstd_p:
//   invstd_t = sum(-invstd_p^3 * (input_t - mean(input_t)) * (input_p - mean_p))
//              / numel
// reduced over `dims`, keeping those dims iff `keepdim`.
static Tensor _invstd_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& mean_p,
    const Tensor& invstd_p,
    IntArrayRef dims,
    int64_t numel,
    bool keepdim) {
  // Subclass-like tensors and zero-tensor tangents go through the
  // out-of-place path; everything else reuses one buffer in place.
  const bool out_of_place =
      areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) ||
      input_t._is_zerotensor();

  Tensor pointwise;
  if (!out_of_place) {
    pointwise = input_t - input_t.mean(dims, true);
    pointwise *= input_p - mean_p;
    pointwise *= -invstd_p.pow(3);
  } else {
    pointwise = -invstd_p.pow(3) * (input_t - input_t.mean(dims, true)) *
        (input_p - mean_p);
  }

  auto reduced = pointwise.sum(dims, keepdim);
  reduced /= numel;
  return reduced;
}

// Helper for the {batch,layer,group}_norm jvps below only.
// Tangent of (input - mean(input, dims)) * invstd(input, dims):
//   (input_t - mean(input_t)) * invstd_p + (input_p - mean_p) * invstd_t
// where invstd_t comes from _invstd_jvp with keepdim=true so it broadcasts.
static Tensor _norm_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& mean_p,
    const Tensor& invstd_p,
    IntArrayRef dims,
    int64_t numel) {
  const auto invstd_t = _invstd_jvp(
      input_p, input_t, mean_p, invstd_p, dims, numel, /*keepdim=*/true);

  // Subclass-like tensors and zero-tensor tangents: build out of place.
  if (areAnyTensorSubclassLike({input_t, input_p, mean_p, invstd_p}) ||
      input_t._is_zerotensor()) {
    return (input_t - input_t.mean(dims, true)) * invstd_p +
        (input_p - mean_p) * invstd_t;
  }

  // Otherwise accumulate into freshly allocated buffers in place.
  Tensor out = input_t - input_t.mean(dims, true);
  out *= invstd_p;
  Tensor centered_term = input_p - mean_p;
  centered_term *= invstd_t;
  out += centered_term;
  return out;
}

// Helper for {batch,layer,group}_norms below only.
// Computes the jvp of `input * weight + bias`, where weight and bias may be
// undefined: the weight term is skipped when weight_p is undefined, and the
// bias term is skipped when bias_t is undefined.
//
// The result is built in `input_t` (reassigned, or updated in place) and
// returned. On the in-place path `temp` is a Tensor handle copy of
// `*input_p` that shares its storage, so `temp *= weight_t` also overwrites
// the tensor held by `input_p`. The callers in this file pass freshly
// computed temporaries for both `input_p` and `input_t`.
static Tensor _affine_jvp(
    const std::optional<Tensor>& input_p,
    Tensor& input_t,
    const Tensor& weight_p,
    const Tensor& weight_t,
    const Tensor& bias_t) {
  // We allow input_p to be optional because if weight_p isn't defined,
  // it may be possible to avoid computing input_p
  TORCH_INTERNAL_ASSERT(input_p.has_value() == weight_p.defined());
  if (weight_p.defined()) {
    // Product rule: (x * w)_t = x_t * w_p + x_p * w_t
    if (areAnyTensorSubclassLike(
            // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
            {input_p.value(), input_t, weight_p, weight_t}) ||
        input_t._is_zerotensor() || weight_t._is_zerotensor()) {
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      input_t = input_t * weight_p + input_p.value() * weight_t;
    } else {
      input_t *= weight_p;
      // `temp` shares storage with *input_p; the `*=` below mutates it.
      // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
      auto temp = input_p.value();
      temp *= weight_t;
      input_t += temp;
    }
  }
  if (bias_t.defined()) {
    // The bias enters additively, so its tangent is simply added.
    if (areAnyTensorSubclassLike({input_t, bias_t}) ||
        input_t._is_zerotensor()) {
      input_t = input_t + bias_t;
    } else {
      input_t += bias_t;
    }
  }
  return input_t;
}

// Forward-mode derivative of batch_norm w.r.t. input, weight and bias.
// Statistics are per channel (dim 1); every other dim is reduced over.
// - train == true: uses the saved batch statistics, whose own tangents are
//   accounted for by _norm_jvp.
// - train == false: uses the running statistics as constants (they must not
//   carry forward grads), so the normalized tangent is input_t * invstd.
// bias_p is not referenced: only bias_t contributes.
Tensor batch_norm_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& weight_p,
    const Tensor& weight_t,
    const Tensor& bias_p,
    const Tensor& bias_t,
    const std::optional<Tensor>& running_mean,
    const std::optional<Tensor>& running_var,
    const Tensor& saved_mean,
    const Tensor& saved_invstd,
    bool train,
    double eps) {
  // channel_view has shape (1, C, 1, ..., 1), so per-channel tensors
  // broadcast against the input; reduce_dims lists every non-channel dim.
  std::vector<int64_t> channel_view = input_t.sizes().vec();
  std::vector<int64_t> reduce_dims;
  int64_t reduced_numel = 1;
  const auto ndim = static_cast<int64_t>(channel_view.size());
  for (int64_t d = 0; d < ndim; ++d) {
    if (d == 1) {
      continue;
    }
    reduced_numel *= input_t.size(d);
    channel_view[d] = 1;
    reduce_dims.push_back(d);
  }

  Tensor mean_p;
  Tensor invstd_p;
  Tensor result_t;
  if (!train) {
    TORCH_INTERNAL_ASSERT(
        running_mean.has_value() && running_var.has_value(),
        "Expect running_mean and running_var to have value when train=false");
    TORCH_CHECK(
        !running_mean.value()._fw_grad(/*level=*/0).defined() &&
            !running_var.value()._fw_grad(/*level=*/0).defined(),
        "batch_norm is not differentiable wrt running_mean and running_var, they cannot have forward grad defined");
    mean_p = running_mean.value().view(channel_view);
    invstd_p = (1 / at::sqrt(running_var.value() + at::Scalar(eps)))
                   .view(channel_view);
    result_t = input_t * invstd_p;
  } else {
    mean_p = saved_mean.view(channel_view);
    invstd_p = saved_invstd.view(channel_view);
    result_t = _norm_jvp(
        input_p, input_t, mean_p, invstd_p, reduce_dims, reduced_numel);
  }

  // The normalized primal is only needed for the weight_t term.
  std::optional<Tensor> normalized_p;
  if (weight_p.defined()) {
    normalized_p = (input_p - mean_p) * invstd_p;
  }
  const auto as_channel_view = [&channel_view](const Tensor& t) -> Tensor {
    return t.defined() ? t.view(channel_view) : t;
  };
  return _affine_jvp(
      normalized_p,
      result_t,
      as_channel_view(weight_p),
      as_channel_view(weight_t),
      as_channel_view(bias_t));
}

// Forward-mode derivative of layer_norm. The trailing normalized_shape.size()
// dims are normalized over. Weight and bias are viewed with the leading
// (batch) dims collapsed to 1 so they broadcast against the input.
// bias_p is not referenced: only bias_t contributes.
Tensor layer_norm_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& weight_p,
    const Tensor& weight_t,
    const Tensor& bias_p,
    const Tensor& bias_t,
    const Tensor& saved_mean,
    const Tensor& saved_invstd,
    c10::SymIntArrayRef normalized_shape) {
  // stats_view: normalized dims collapsed to 1 (used for saved_mean/invstd).
  // affine_view: leading dims collapsed to 1 (used for weight/bias).
  std::vector<int64_t> stats_view = input_t.sizes().vec();
  std::vector<int64_t> affine_view = input_t.sizes().vec();
  std::vector<int64_t> reduce_dims;
  int64_t reduced_numel = 1;

  const size_t num_leading = stats_view.size() - normalized_shape.size();
  for (const auto i : c10::irange(stats_view.size())) {
    const auto d = static_cast<int64_t>(i);
    if (i >= num_leading) {
      reduced_numel *= input_t.size(d);
      stats_view[i] = 1;
      reduce_dims.push_back(d);
    } else {
      affine_view[i] = 1;
    }
  }

  const auto mean_p = saved_mean.view(stats_view);
  const auto invstd_p = saved_invstd.view(stats_view);
  auto result_t =
      _norm_jvp(input_p, input_t, mean_p, invstd_p, reduce_dims, reduced_numel);

  // The normalized primal is only needed for the weight_t term.
  std::optional<Tensor> normalized_p;
  if (weight_p.defined()) {
    normalized_p = (input_p - mean_p) * invstd_p;
  }
  const auto as_affine_view = [&affine_view](const Tensor& t) -> Tensor {
    return t.defined() ? t.view(affine_view) : t;
  };
  return _affine_jvp(
      normalized_p,
      result_t,
      as_affine_view(weight_p),
      as_affine_view(weight_t),
      as_affine_view(bias_t));
}

// Forward-mode derivative of group_norm. The input (N, C, *) is viewed as
// (1, N * groups, rest), so the normalization becomes a batch_norm over dim 1
// that reuses the saved per-(sample, group) statistics. The per-channel
// affine parameters are then applied on the original shape.
// bias_p is not referenced: only bias_t contributes.
Tensor group_norm_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& weight_p,
    const Tensor& weight_t,
    const Tensor& bias_p,
    const Tensor& bias_t,
    const Tensor& saved_mean,
    const Tensor& saved_invstd,
    int64_t groups) {
  const auto input_shape = input_p.sizes();
  const int64_t N = input_p.size(0);
  const int64_t C = input_p.size(1);

  // The last dim is inferred with -1, except for an empty batch (N == 0),
  // where 1 is used instead.
  const std::array<int64_t, 3> grouped_shape = {1, N * groups, N ? -1 : 1};
  const auto input_t_grouped = input_t.view(grouped_shape);
  const auto input_p_grouped = input_p.view(grouped_shape);

  auto result_t = batch_norm_jvp(
                      input_p_grouped,
                      input_t_grouped,
                      /*weight_p=*/{},
                      /*weight_t=*/{},
                      /*bias_p=*/{},
                      /*bias_t=*/{},
                      /*running_mean=*/{},
                      /*running_var=*/{},
                      saved_mean,
                      saved_invstd,
                      /*train=*/true,
                      /*eps=*/0)
                      .view(input_shape);

  // The normalized primal is only needed for the weight_t term.
  std::optional<Tensor> normalized_p;
  if (weight_p.defined()) {
    // Per-(sample, group) statistics viewed as (1, N * groups, 1).
    std::vector<int64_t> stats_view(input_t_grouped.dim(), 1);
    stats_view[1] = input_t_grouped.size(1);
    normalized_p = ((input_p_grouped - saved_mean.view(stats_view)) *
                    saved_invstd.view(stats_view))
                       .view(input_shape);
  }

  // Per-channel affine parameters viewed as (1, C, 1, ..., 1).
  std::vector<int64_t> affine_view(input_p.dim(), 1);
  affine_view[1] = C;
  const auto as_affine_view = [&affine_view](const Tensor& t) -> Tensor {
    return t.defined() ? t.view(affine_view) : t;
  };

  return _affine_jvp(
      normalized_p,
      result_t,
      as_affine_view(weight_p),
      as_affine_view(weight_t),
      as_affine_view(bias_t));
}

// Tangent of the per-(sample, group) mean saved by group_norm: view input_t
// as (1, N * groups, rest), average over the last dim, and reshape the result
// like mean_p.
Tensor group_norm_mean_jvp(
    const Tensor& input_t,
    const Tensor& mean_p,
    int64_t groups) {
  const int64_t N = input_t.size(0);
  // The last dim is inferred with -1, except for an empty batch (N == 0).
  const auto per_group = input_t.view({1, N * groups, N ? -1 : 1});
  return per_group.mean({2}, false).view_as(mean_p);
}

// Tangent of the per-(sample, group) invstd saved by group_norm.
// input_p / input_t are viewed as (1, N * groups, rest), _invstd_jvp reduces
// over the last dim, and the result is reshaped like invstd_p.
Tensor group_norm_invstd_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& mean_p,
    const Tensor& invstd_p,
    int64_t groups) {
  int64_t N = input_p.size(0);

  // The last dim is inferred with -1, except for an empty batch (N == 0).
  const std::array<int64_t, 3> view_shape = {1, N * groups, N ? -1 : 1};

  auto input_t_reshaped = input_t.view(view_shape);
  auto input_p_reshaped = input_p.view(view_shape);

  // _invstd_jvp takes (input_p, input_t, ...): the primal must come first.
  // Its zero-tensor fast-path check and its centering of the tangent by its
  // own mean both assume the second argument is the tangent.
  return _invstd_jvp(
             input_p_reshaped,
             input_t_reshaped,
             mean_p.view(view_shape),
             invstd_p.view(view_shape),
             /*dims=*/{2},
             /*numel=*/input_t_reshaped.size(2),
             /*keepdim=*/false)
      .view_as(invstd_p);
}

// Gathers `input` along `dim` with indices produced by a reduction that may
// have dropped `dim` (keepdim == false). In that case the indices are
// unsqueezed before the gather and the result is squeezed afterwards, so the
// output has the same shape as `indices` either way.
Tensor gather_with_keepdimed_indices(
    const Tensor& input,
    int64_t dim,
    const Tensor& indices,
    bool keepdim) {
  if (keepdim) {
    return at::gather(input, dim, indices);
  }
  return at::gather(input, dim, indices.unsqueeze(dim)).squeeze(dim);
}

// Let A in \C^{m \times n}, then its pivoted LU decomposition is
// A = P L U, where P is a permutation matrix.
//
// Useful notation:
// Let o denote the elementwise, or Hadamard, product.
// k := min(m, n)
// 1 := ones(k, k),
// 1_U = 1.triu();
// 1_L = 1 - 1_U (strictly lower triangular: its diagonal is zero)
// For a matrix A, A^H := A.mH()
//
// Below we derive the backward algorithm for the case when m <= n.
// The case m > n could be obtained using the same idea.
// Since we assume m <= n, the LU decomposition of A can be written as
// A = (A1 | A2) = P L (U1 | U2) where A1, U1 in \C^{m \times m} and
// A2, U2 in \C^{m \times (n - m)}.
//
// Forward AD:
//
// dA = P dL U + P L dU => [left-multiply P^T]
// (P^T dA1 | P^T dA2) = (dL U1 + L dU1 | dL U2 + L dU2) (*)
// From (*):
// P^T dA1 = dL U1 + L dU1 => [left-multiply by L^{-1}, right-multiply by U1^{-1}]
// L^{-1} P^T dA1 U1^{-1} = L^{-1} dL + dU1 U1^{-1} (**).
// Note, L is lower-triangular, and so is its inverse, hence L^{-1} dL is
// lower-triangular.
// Also, since the diagonal of L (all ones) is never exposed explicitly (packed
// representation), the diagonal of dL is zero, and hence diag(L^{-1} dL) = 0.
// Assuming that U1 is full-rank, similarly, dU1 U1^{-1} is upper-triangular.
// Combining these observations we conclude:
//
// L^{-1} dL = (L^{-1} P^T dA1 U1^{-1}) o 1_L,
// dU1 U1^{-1} = (L^{-1} P^T dA1 U1^{-1}) o 1_U.
//
// Hence,
// dL = L [(L^{-1} P^T dA1 U1^{-1}) o 1_L],
// dU1 = [(L^{-1} P^T dA1 U1^{-1}) o 1_U] U1.
// As for dU2, from (*) it follows
// P^T dA2 = dL U2 + L dU2 =>
// dU2 = L^{-1} (P^T dA2 - dL U2).
//
// Backward AD:
//
// The following equality comes in very handy:
// Tr(A (B o C)) = Tr((A o B^T) C) (!)
// or in other words, given that X -> B o X is a pointwise operation
// its Jacobian is diagonal, so its differential is self-adjoint
// <A, B o C> = <A o B, C>
//
// Tr(A_grad^H dA) = Tr(L_grad^H dL) + Tr(U_grad^H dU), then
//
// Tr(L_grad^H dL) = Tr(L_grad^H L [(L^{-1} P^T dA1 U1^{-1}) o 1_L]) = [using (!)]
//                 = Tr((L_grad^H L o 1_L^T) L^{-1} P^T dA1 U1^{-1})
//                 = [using the cyclic property of Tr]
//                 = Tr(U1^{-1} (L_grad^H L o 1_L^T) L^{-1} P^T dA1)
//
// Similarly, using (!) and the cyclic property of the trace operator:
// Tr(U_grad^H dU) = Tr(U1_grad^H dU1) + Tr(U2_grad^H dU2)
//                 = Tr(U1^{-1} (U1 U1_grad^H o 1_U^T) L^{-1} P^T dA1)
//                   + Tr(U2_grad^H L^{-1} P^T dA2)
//                   - Tr(U1^{-1} (U2 U2_grad^H o 1_L^T) L^{-1} P^T dA1)
//
// By combining the matrices to the left from dA1 and dA2 and then applying
// conjugate transposition, we finally arrive at:
//
// A1_grad = P L^{-H} [L^H L_grad o 1_L + U1_grad U1^H o 1_U
//                     - U2_grad U2^H o 1_L] U1^{-H},
// A2_grad = P L^{-H} U2_grad
// Backward of linalg_lu (A = P L U) following the derivation above.
// Either of L_grad / U_grad may be undefined; an undefined Tensor is returned
// when both are. The permutation P is applied to the result only if `pivot`.
// The square, wide (m < n) and tall (m > n) cases are handled separately.
Tensor linalg_lu_backward(
    const Tensor& L_grad,
    const Tensor& U_grad,
    const Tensor& P,
    const Tensor& L,
    const Tensor& U,
    const bool pivot) {
  // TF32 is disabled for the matmuls / triangular solves below.
  at::NoTF32Guard disable_tf32;
  // Return early if there's nothing to do
  if (!L_grad.defined() && !U_grad.defined()) {
    return {};
  }

  // L.shape == (..., m, k)
  // U.shape == (..., k, n)
  auto m = L.sym_size(-2);
  auto n = U.sym_size(-1);
  auto k = std::min(m, n);

  if (m == n) {
    // A_grad = P L^{-H} [L^H L_grad o 1_L + U_grad U^H o 1_U] U^{-H},
    auto A_grad = L_grad.defined() ? L.mH().matmul(L_grad).tril(-1) : Tensor{};
    if (U_grad.defined()) {
      A_grad = A_grad.defined() ? A_grad + U_grad.matmul(U.mH()).triu()
                                : U_grad.matmul(U.mH()).triu();
    }
    // Right-solve with U^H (lower triangular): A_grad <- A_grad U^{-H}
    A_grad = at::linalg_solve_triangular(
        U.mH(),
        A_grad,
        /*upper=*/false,
        /*left=*/false);
    // Left-solve with L^H (unit upper triangular): A_grad <- L^{-H} A_grad
    A_grad = at::linalg_solve_triangular(
        L.mH(),
        A_grad,
        /*upper=*/true,
        /*left=*/true,
        /*unitriangular=*/true);

    return pivot ? P.matmul(A_grad) : std::move(A_grad);
  } else if (m < n) {
    // Wide case (k = m):
    // A1_grad = P L^{-H} [U1_grad o 1_U
    //                     + ((L^H L_grad - (U_grad o 1_U) U^H) o 1_L) U1^{-H}]
    // A2_grad = P L^{-H} U2_grad
    // (A2_grad is zero when U_grad is undefined.)
    const auto get_U1 = [n, k](const Tensor& U) {
      return n == k ? U : U.narrow_symint(-1, 0, k);
    };
    const auto get_U2 = [n, k](const Tensor& U) {
      return U.narrow_symint(-1, k, n - k);
    };

    auto A_grad = L_grad.defined() ? L.mH().matmul(L_grad) : Tensor{};
    if (U_grad.defined()) {
      A_grad = A_grad.defined() ? A_grad - U_grad.triu().matmul(U.mH())
                                : -U_grad.triu().matmul(U.mH());
    }
    // Right-solve with U1^H: A_grad <- (A_grad o 1_L) U1^{-H}
    A_grad = at::linalg_solve_triangular(
        get_U1(U).mH(),
        A_grad.tril(-1),
        /*upper=*/false,
        /*left=*/false);

    if (U_grad.defined()) {
      A_grad =
          at::cat({A_grad + get_U1(U_grad).triu(), get_U2(U_grad)}, /*dim=*/-1);
    }

    // Left-solve with L^H (unit upper triangular): A_grad <- L^{-H} A_grad
    A_grad = at::linalg_solve_triangular(
        L.mH(),
        A_grad,
        /*upper=*/true,
        /*left=*/true,
        /*unitriangular=*/true);

    if (!U_grad.defined()) {
      A_grad = at::cat({A_grad, at::zeros_like(get_U2(U))}, /*dim=*/-1);
    }
    if (pivot) {
      A_grad = P.matmul(A_grad);
    }
    return A_grad;
  } else {
    // Tall case (k = n):
    // A1_grad = P [L1_grad o 1_L
    //              + L1^{-H} ((U_grad U^H - L^H (L_grad o 1_L)) o 1_U)] U^{-H}
    // A2_grad = P L2_grad U^{-H}
    // (A2_grad is zero when L_grad is undefined.)

    const auto get_L1 = [m, k](const Tensor& L) {
      return m == k ? L : L.narrow_symint(-2, 0, k);
    };
    const auto get_L2 = [m, k](const Tensor& L) {
      return L.narrow_symint(-2, k, m - k);
    };

    auto A_grad = U_grad.defined() ? U_grad.matmul(U.mH()) : Tensor{};
    if (L_grad.defined()) {
      A_grad = A_grad.defined() ? A_grad - L.mH().matmul(L_grad.tril(-1))
                                : -L.mH().matmul(L_grad.tril(-1));
    }
    // Left-solve with L1^H (unit upper triangular):
    // A_grad <- L1^{-H} (A_grad o 1_U)
    A_grad = at::linalg_solve_triangular(
        get_L1(L).mH(),
        A_grad.triu(),
        /*upper=*/true,
        /*left=*/true,
        /*unitriangular=*/true);

    if (L_grad.defined()) {
      A_grad = at::cat(
          {A_grad + get_L1(L_grad).tril(-1), get_L2(L_grad)}, /*dim=*/-2);
    }

    // Right-solve with U^H (lower triangular): A_grad <- A_grad U^{-H}
    A_grad = at::linalg_solve_triangular(
        U.mH(),
        A_grad,
        /*upper=*/false,
        /*left=*/false);

    if (!L_grad.defined()) {
      A_grad = at::cat({A_grad, at::zeros_like(get_L2(L))}, /*dim=*/-2);
    }
    if (pivot) {
      A_grad = P.matmul(A_grad);
    }
    return A_grad;
  }
}

// Backward of lu_factor_ex. The packed LU matrix of shape (..., m, n) is
// unpacked into P, L, U. The packed gradient is split the same way:
//   - L's part is its first k = min(m, n) columns,
//   - U's part is its first k rows.
// Both parts are forwarded to linalg_lu_backward.
Tensor lu_factor_ex_backward(
    const Tensor& grad,
    const Tensor& LU,
    const Tensor& pivs,
    const bool pivot) {
  auto [P, L, U] =
      at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots*/ pivot);

  // L.shape == (..., m, k), U.shape == (..., k, n)
  const auto rows = LU.sym_size(-2);
  const auto cols = LU.sym_size(-1);
  const auto k = std::min(rows, cols);
  return linalg_lu_backward(
      /*L_grad=*/grad.narrow_symint(-1, 0, k),
      /*U_grad=*/grad.narrow_symint(-2, 0, k),
      P,
      L,
      U,
      pivot);
}

// This function is based on the forward AD derivations outlined
// in the description to the linalg_lu_backward function.
// Given the tangent dA of A = P L U, returns the tuple (dL, dU).
std::tuple<Tensor, Tensor> linalg_lu_jvp(
    const Tensor& dA,
    const Tensor& P,
    const Tensor& L,
    const Tensor& U,
    const bool pivot) {
  at::NoTF32Guard disable_tf32;

  auto m = dA.size(-2);
  auto n = dA.size(-1);
  auto k = std::min(m, n);

  // P^T dA (P is a permutation matrix); dA itself when pivot is false.
  auto PdA = pivot ? P.mT().matmul(dA) : dA;

  // similar to the backward implementation, we also consider block structures
  // such as: for a matrix A of size m x n we decompose it as A = (A1 | A2) with
  // A1 of size m x m if m <= n and A = (A1^T | A2^T)^T with A1 of size n x n if
  // m > n.
  auto PdA1 = PdA.narrow(-2, 0, k).narrow(-1, 0, k);
  auto L1 = L.narrow(-2, 0, k).narrow(-1, 0, k);
  auto U1 = U.narrow(-2, 0, k).narrow(-1, 0, k);

  // We form dK = L1^{-1} PdA1 U1^{-1} with two triangular solves: first a
  // unit-lower left-solve with L1, then an upper right-solve with U1 (the
  // second one could be done in place, see the TODO below).
  auto dK = at::linalg_solve_triangular(
      L1, PdA1, /*upper=*/false, /*left=*/true, /*unitriangular*/ true);

  // TODO We should be able to do this in-place. At the moment it raises:
  //  RuntimeError: linalg_solve_triangular(): functions with out=...
  //  arguments don't support automatic differentiation, but one of the
  //  arguments requires grad.

  //  at::linalg_solve_triangular_out(dK, U1, dK, /*upper=*/true,
  //  /*left=*/false);
  dK = at::linalg_solve_triangular(U1, dK, /*upper=*/true, /*left=*/false);

  // dL1 = L1 (dK o 1_L), dU1 = (dK o 1_U) U1
  auto dL1 = L1.matmul(dK.tril(-1));
  auto dU1 = dK.triu().matmul(U1);

  if (m == n) {
    return std::make_tuple(std::move(dL1), std::move(dU1));
  } else if (m < n) {
    // we only need to update dU2 defined as
    // dU2 := L1^{-1} PdA2 - dK.tril(-1) U2
    const auto PdA2 = PdA.narrow(-1, k, n - k);
    const auto U2 = U.narrow(-1, k, n - k);
    auto dU2 =
        at::linalg_solve_triangular(
            L1, PdA2, /*upper=*/false, /*left=*/true, /*unitriangular*/ true) -
        dK.tril(-1).matmul(U2);
    return std::make_tuple(
        std::move(dL1), at::cat({std::move(dU1), std::move(dU2)}, /*dim=*/-1));
  } else {
    // we only need to update dL2 defined as
    // dL2 := PdA2 U1^{-1} - L2 dK.triu()
    const auto PdA2 = PdA.narrow(-2, k, m - k);
    const auto L2 = L.narrow(-2, k, m - k);
    auto dL2 =
        at::linalg_solve_triangular(U1, PdA2, /*upper=*/true, /*left=*/false) -
        L2.matmul(dK.triu());
    return std::make_tuple(
        at::cat({std::move(dL1), std::move(dL2)}, /*dim=*/-2), std::move(dU1));
  }
}

// Forward-mode derivative of lu_factor_ex. The tangents (dL, dU) are computed
// with linalg_lu_jvp and then packed the same way LU is. They are added over
// their shared k x k block, where k = min(m, n).
Tensor lu_factor_ex_jvp(
    const Tensor& dA,
    const Tensor& LU,
    const Tensor& pivs,
    const bool pivot) {
  auto [P, L, U] =
      at::lu_unpack(LU, pivs, /*unpack_data=*/true, /*unpack_pivots=*/pivot);
  auto [dL, dU] = linalg_lu_jvp(dA, P, L, U, pivot);

  const auto rows = dA.size(-2);
  const auto cols = dA.size(-1);
  if (rows < cols) {
    // Wide: dU is (rows, cols) and dL is (rows, rows).
    dU.narrow(-1, 0, rows).add_(dL);
    return dU;
  }
  // Square or tall: dL is (rows, cols) and dU is (cols, cols).
  dL.narrow(-2, 0, cols).add_(dU);
  return dL;
}

// Forward-mode derivative of logsumexp over `dim`:
//   logsumexp_t = sum(exp(self_p) * self_t) / sum(exp(self_p))
// evaluated with exp(self_p - amax(self_p)) in place of exp(self_p) for
// numerical stability (the shift cancels in the ratio).
Tensor logsumexp_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    IntArrayRef dim,
    bool keepdim) {
  // NB: for simplicity, we recompute some values that can be reused from
  // forward
  auto self_p_exp = [&self_p, &dim]() {
    if (self_p.sym_numel() > 0) {
      return (self_p - at::amax(self_p, dim, true))
          .exp(); // Use the exp-normalize trick
    } else {
      // amax fails if numel() == 0, in which case it doesn't matter anyway
      return self_p.exp();
    }
  }();

  auto sumexp_p = self_p_exp.sum(dim, keepdim);

  // NB: it's OK for logsumexp_jvp to be reused for formulas like
  // softmax/log_softmax
  //     that only have one differentiable input, because that means self_t are
  //     never zerotensors
  TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor());
  if (areAnyTensorSubclassLike({self_p, self_t})) {
    auto result = (self_p_exp * self_t).sum(dim, keepdim);
    result /= sumexp_p;
    return result;
  } else {
    // self_p_exp is a fresh local buffer, so it can be reused in place.
    self_p_exp *= self_t;
    auto sumexp_t = self_p_exp.sum(dim, keepdim);
    return sumexp_t /= sumexp_p;
  }
}

// Pass-through backward helper: emits a warning each time it runs and
// returns the incoming gradient unchanged.
Tensor warn_backwards(const Tensor& grad_output) {
  // Warn first, then forward the gradient untouched.
  TORCH_WARN("Warn from backward");
  const Tensor& passthrough = grad_output;
  return passthrough;
}

// Backward shim for the cuDNN convolution ops, which don't compute a bias
// gradient. It delegates to the generic convolution_backward with the bias
// gradient disabled and returns only (grad_input, grad_weight); slicing the
// 3-tuple is otherwise awkward. It will be removed when cudnn_convolution
// and cudnn_convolution_transpose go away.
// Returns two undefined Tensors if grad_output is undefined.
std::tuple<Tensor, Tensor> _cudnn_convolution_backward(
    const at::Tensor& self,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    at::SymIntArrayRef padding,
    at::SymIntArrayRef output_padding,
    at::SymIntArrayRef stride,
    at::SymIntArrayRef dilation,
    bool transposed,
    c10::SymInt groups,
    ::std::array<bool, 2> output_mask) {
  if (!grad_output.defined()) {
    return std::tuple<Tensor, Tensor>{};
  }

  // No bias is passed, and its gradient is never requested.
  const std::array<bool, 3> full_mask = {output_mask[0], output_mask[1], false};
  auto all_grads = at::convolution_backward_symint(
      grad_output,
      self,
      weight,
      /*bias_sizes=*/std::nullopt,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      std::move(groups),
      full_mask);
  return std::make_tuple(
      std::move(std::get<0>(all_grads)), std::move(std::get<1>(all_grads)));
}

// Forward-mode derivative of scatter_reduce w.r.t. self and src.
//  * "sum" / "mean": the op is linear, so it is applied to the tangents.
//  * "amin" / "amax": each output tangent is the average of the tangents of
//    the self/src entries that attain the result at that position.
//  * anything else: not implemented; returns an undefined Tensor.
Tensor scatter_reduce_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    int dim,
    const Tensor& index,
    const Tensor& src_p,
    const Tensor& src_t,
    c10::string_view reduce,
    bool include_self,
    const Tensor& result) {
  const bool is_linear = reduce == "sum" || reduce == "mean";
  if (is_linear) {
    // The function is linear
    return at::scatter_reduce(self_t, dim, index, src_t, reduce, include_self);
  }

  const bool is_extremum = reduce == "amin" || reduce == "amax";
  if (!is_extremum) {
    // Not implemented
    return Tensor{};
  }

  // Which self / src entries attain the result they contribute to.
  const auto self_hits = self_p == result;
  const auto src_hits = src_p == at::gather(result, dim, index);
  // Number of attaining contributors per output position (the averaging
  // denominator), honoring include_self like the forward.
  const auto num_hits =
      self_hits.to(self_t.dtype())
          .scatter_reduce(
              dim, index, src_hits.to(self_t.dtype()), "sum", include_self);
  const auto src_t_hits = at::where(src_hits, src_t, 0.);
  return at::where(self_hits, self_t, 0.)
      .scatter_reduce(dim, index, src_t_hits, "sum", include_self)
      .div(num_hits);
}

std::tuple<Tensor, Tensor> scatter_reduce_backward(
    const Tensor& grad,
    const Tensor& self,
    int dim,
    const Tensor& index,
    const Tensor& src,
    c10::string_view reduce,
    bool include_self,
    const Tensor& result) {
  Tensor grad_self, grad_src;

  // FIXME: complex gradients not handled correctly
  // For now this is ok as scatter_reduce isn't added to the whitelist
  // in tools/autograd/gen_variable_type.py

  if (!grad.defined()) {
    return std::make_tuple(grad_self, grad_src);
  }

  if (reduce == "sum") {
    grad_self = grad;
    grad_src = grad.gather(dim, index);
  } else if (reduce == "prod") {
    // Explicitly compute exclusive prod for elements in self/src that are 0
    Tensor masked_self = self.masked_fill(self == 0, 1);
    Tensor masked_self_result =
        masked_self.scatter_reduce(dim, index, src, reduce, include_self);
    grad_self = grad * masked_self_result / masked_self;
    Tensor src_zero = src == 0;
    Tensor src_num_zeros =
        zeros_like(self)
            .scatter_add(dim, index, src_zero.to(self.dtype()))
            .gather(dim, index);
    Tensor src_single_zero = bitwise_and(src_zero, src_num_zeros == 1);
    // For src positions with src_single_zero, grad * result.gather(dim,index) /
    // src.masked_fill(src_zero, 1) would incorrectly propagate zeros as the
    // gradient
    Tensor masked_src = src.masked_fill(src_single_zero, 1);
    Tensor masked_src_result =
        self.scatter_reduce(dim, index, masked_src, reduce, include_self);
    Tensor grad_src1 = where(
        src_single_zero,
        (grad * masked_src_result).gather(dim, index),
        (grad * result).gather(dim, index) / src.masked_fill(src_zero, 1));
    // GradMode::is_enabled() - adding the autograd Node is a no-op if autograd
    // is disabled; this also avoids having the item() call in the usual case.
    if (GradMode::is_enabled() && (src_num_zeros > 1).any().item<bool>()) {
      auto node = std::make_shared<DelayedError>(
          "scatter_reduce(): Double backward is unsupported for src when >1 zeros in src are scattered to the same position in self",
          /* num inputs */ 1);
      auto result = node->apply({std::move(grad_src1)});
      grad_src = result[0];
    } else {
      grad_src = grad_src1;
    }
  } else if (reduce == "mean") {
    Tensor N = include_self ? ones_like(grad) : zeros_like(grad);
    N = N.scatter_add(dim, index, ones_like(src));
    N.masked_fill_(N == 0, 1);
    grad_self = grad / N;
    Tensor N_src = N.gather(dim, index);
    grad_src = grad.gather(dim, index) / N_src;
  } else if (reduce == "amax" || reduce == "amin") {
    // Evenly distribute gradient when there are multiple max/mins
    Tensor value = result.gather(dim, index);
    Tensor self_is_result = (self == result).to(self.scalar_type());
    Tensor src_is_result = (src == value).to(self.scalar_type());
    Tensor N_to_distribute =
        self_is_result.scatter_add(dim, index, src_is_result);
    Tensor grad_distributed = grad / N_to_distribute;
    grad_self = (self == result) * grad_distributed;
    grad_src = (src == value) * grad_distributed.gather(dim, index);
  } else {
    AT_ERROR(
        "Expected 'reduce' to be one of 'sum', 'prod', 'mean', 'amax', 'amin' but got ",
        reduce,
        ".");
  }

  if (!include_self) {
    grad_self = grad_self.scatter(dim, index, 0);
  }

  return std::make_tuple(grad_self, grad_src);
}

// Backward of _to_copy: converts the gradient back to the input's options
// (never forcing a copy, blocking). Handles R->C copies without raising a
// warning: if the input was real but the gradient is complex, only the real
// part of the gradient is converted.
Tensor _to_copy_backward(
    const Tensor& grad_,
    const c10::TensorOptions& self_options) {
  const bool self_is_complex =
      c10::isComplexType(self_options.dtype().toScalarType());
  if (!self_is_complex && grad_.is_complex()) {
    return at::real(grad_).to(
        self_options, /*non_blocking=*/false, /*copy=*/false);
  }
  return grad_.to(self_options, /*non_blocking=*/false, /*copy=*/false);
}

// Backward of index_reduce(self, dim, index, source, reduce, include_self)
// for reduce in {"prod", "mean", "amax", "amin"}. Returns (grad_self,
// grad_src); both are undefined if `grad` is undefined. When include_self is
// false, self does not take part in the reduction at the indexed positions,
// so grad_self is zeroed there at the end.
std::tuple<Tensor, Tensor> index_reduce_backward(
    const Tensor& grad,
    const Tensor& self,
    int dim,
    const Tensor& index,
    const Tensor& source,
    c10::string_view reduce,
    bool include_self,
    const Tensor& result) {
  Tensor grad_self, grad_src;

  // FIXME: index_add's backward formula has a special case for source.dim == 0
  // but this case seems to throw the error "IndexError: dimension specified as
  // 0 but tensor has no dimensions" look into whether this case is reachable
  // and should be covered here

  if (!grad.defined()) {
    return std::make_tuple(grad_self, grad_src);
  }

  if (reduce == "prod") {
    Tensor masked_self = self.masked_fill(self == 0, 1);
    Tensor masked_self_result =
        masked_self.index_reduce(dim, index, source, reduce, include_self);
    grad_self = grad * masked_self_result / masked_self;
    Tensor src_zero = source == 0;
    Tensor src_num_zeros = zeros_like(self)
                               .index_add(dim, index, src_zero.to(self.dtype()))
                               .index_select(dim, index);
    Tensor src_single_zero = bitwise_and(src_zero, src_num_zeros == 1);
    // For src positions with src_single_zero, (grad *
    // result).index_select(dim,index) / source.masked_fill(src_zero, 1) would
    // incorrectly propagate zeros as the gradient
    Tensor masked_src = source.masked_fill(src_single_zero, 1);
    Tensor masked_src_result =
        self.index_reduce(dim, index, masked_src, reduce, include_self);
    Tensor grad_src1 = where(
        src_single_zero,
        (grad * masked_src_result).index_select(dim, index),
        (grad * result).index_select(dim, index) /
            source.masked_fill(src_zero, 1));
    // GradMode::is_enabled() - adding the autograd Node is a no-op if autograd
    // is disabled; this also avoids having the item() call in the usual case.
    if (GradMode::is_enabled() && (src_num_zeros > 1).any().item<bool>()) {
      auto node = std::make_shared<DelayedError>(
          "index_reduce(): Double backward is unsupported for source when >1 zeros in source are scattered to the same position in self",
          /* num inputs */ 1);
      auto delayed = node->apply({std::move(grad_src1)});
      grad_src = delayed[0];
    } else {
      grad_src = grad_src1;
    }
  } else if (reduce == "mean") {
    // N counts the contributors averaged at each output position.
    Tensor N = include_self ? ones_like(grad) : zeros_like(grad);
    N = N.index_add(dim, index, ones_like(source));
    N.masked_fill_(N == 0, 1);
    grad_self = grad / N;
    Tensor N_src = N.index_select(dim, index);
    grad_src = grad.index_select(dim, index) / N_src;
  } else if (reduce == "amax" || reduce == "amin") {
    // Evenly distribute gradient when there are multiple max/mins
    Tensor value = result.index_select(dim, index);
    Tensor self_is_result = (self == result).to(self.scalar_type());
    if (!include_self) {
      // At indexed positions self is excluded from the reduction, so it must
      // not be counted among the tied contributors; otherwise source would
      // only receive a fraction of the gradient there whenever self happens
      // to equal the result.
      self_is_result = self_is_result.index_fill(dim, index, 0);
    }
    Tensor source_is_result = (source == value).to(self.scalar_type());
    Tensor N_to_distribute =
        self_is_result.index_add(dim, index, source_is_result);
    Tensor grad_distributed = grad / N_to_distribute;
    grad_self = self_is_result * grad_distributed;
    grad_src = source_is_result * grad_distributed.index_select(dim, index);
  } else {
    AT_ERROR(
        "Expected 'reduce' to be one of 'prod', 'amax', 'amin' or 'mean' but got ",
        reduce,
        ".");
  }

  if (!include_self) {
    grad_self = grad_self.index_fill(dim, index, 0);
  }

  return std::make_tuple(grad_self, grad_src);
}

// Backward of take: writes `grad` into a zero tensor shaped like `self` at
// the flat positions given by `indices`, accumulating repeated indices
// (put's `accumulate` flag is true).
Tensor take_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& indices) {
  auto grad_self = at::zeros_like(self);
  // For Composite Compliance: when `grad` or `indices` is a tensor subclass
  // but the freshly created `grad_self` is not, use the out-of-place `put`.
  const bool use_out_of_place = areAnyTensorSubclassLike({grad, indices});
  if (!use_out_of_place) {
    return grad_self.put_(indices, grad, true);
  }
  return grad_self.put(indices, grad, true);
}

// Backward of the to_sparse* conversions: brings the gradient back to the
// input's layout. A strided input (a path also taken by nested tensors) gets
// a dense gradient. Otherwise the gradient is re-sparsified with the input's
// layout and, when the input had one, its blocksize.
Tensor to_sparse_backward(
    const Tensor& grad,
    const c10::Layout self_layout,
    const c10::OptionalArrayRef<c10::SymInt>& self_blocksize) {
  // Path for strided and nested
  if (self_layout == c10::kStrided) {
    return grad.to_dense();
  }

  OptionalIntArrayRef target_blocksize = std::nullopt;
  if (self_blocksize.has_value()) {
    target_blocksize = c10::asIntArrayRefSlowOpt(*self_blocksize);
  }
  return grad.to_sparse(self_layout, target_blocksize);
}

// Backward for a single LSTM layer computed by `mkldnn_rnn_layer`, written
// entirely with ATen ops. The per-step gates and states are recomputed from
// the inputs, and the LSTM backward recurrence is then run explicitly in
// reverse time order. The oneDNN `workspace` is not read. Using plain ATen ops
// is presumably what makes this path differentiable (double backward) —
// TODO confirm against derivatives.yaml.
//
// Gate layout: the fused gate pre-activation is split with
// `unsafe_chunk(4, 1)` into (i, f, g, o), with sigmoid on i/f/o and tanh on g:
//   c_t = f * c_{t-1} + i * g,   h_t = o * tanh(c_t)
//
// `input` is indexed along dim 0 as the time axis. It is presumably
// (seq_len, batch, input_size); `batch_first` is not consulted here.
// When `reverse` is true, time step `seq` reads input[seq_length - seq].
//
// Returns (grad_input, grad_weight0, grad_weight1, grad_weight2,
// grad_weight3, grad_hx, grad_cx). The two bias gradients are the same
// tensor, because both biases add into the same gate pre-activation.
// Returns all-undefined tensors when no incoming gradient is defined.
//
// Not read by this function: hy_, cy_, mode, num_layers, train,
// bidirectional, batch_sizes, batch_first, workspace.
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>
mkldnn_rnn_layer_differentiable_backward(
    const Tensor& input,
    const Tensor& weight0,
    const Tensor& weight1,
    const Tensor& weight2,
    const Tensor& weight3,
    const Tensor& hx_,
    const Tensor& cx_tmp,
    const Tensor& output,
    const Tensor& hy_,
    const Tensor& cy_,
    const std::optional<Tensor>& grad_output_r_opt,
    const std::optional<Tensor>& grad_hy_r_opt,
    const std::optional<Tensor>& grad_cy_r_opt,
    bool reverse,
    int64_t mode,
    int64_t hidden_size,
    int64_t num_layers,
    bool has_biases,
    bool train,
    bool bidirectional,
    at::IntArrayRef batch_sizes,
    bool batch_first,
    const at::Tensor& workspace) {
  const Tensor& grad_output_r =
      c10::value_or_else(grad_output_r_opt, [] { return Tensor(); });
  const Tensor& grad_hy_r =
      c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); });
  const Tensor& grad_cy_r =
      c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); });
  // No incoming gradient at all: every input gradient is undefined.
  if (!grad_output_r.defined() && !grad_hy_r.defined() &&
      !grad_cy_r.defined()) {
    return std::make_tuple(
        Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor());
  }
  // Materialize missing incoming gradients as zeros, so the recurrence below
  // can treat all of them uniformly.
  auto grad_output = grad_output_r.defined()
      ? grad_output_r.contiguous()
      : at::zeros_like(output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_hy = grad_hy_r.defined()
      ? grad_hy_r.contiguous()
      : at::zeros_like(hx_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto grad_cy = cx_tmp.defined()
      ? (grad_cy_r.defined()
             ? grad_cy_r.contiguous()
             : at::zeros_like(cx_tmp, LEGACY_CONTIGUOUS_MEMORY_FORMAT))
      : grad_cy_r.contiguous();
  // Without biases, substitute zero biases so the gate recomputation can
  // always call at::linear with a bias.
  // NOTE(review): db is still returned for weight2/weight3 when
  // has_biases == false; confirm callers ignore those slots in that case.
  Tensor bias_ih, bias_hh;
  if (has_biases) {
    bias_ih = weight2;
    bias_hh = weight3;
  } else {
    bias_ih = at::zeros(
        {4 /* num_bias_gates of LSTM */ * hidden_size}, weight0.options());
    bias_hh = at::zeros(
        {4 /* num_bias_gates of LSTM */ * hidden_size}, weight0.options());
  }

  // Forward recomputation.
  // - layer_gates is indexed by input position (x_index).
  // - layer_states is indexed by processing step: layer_states[0] holds the
  //   initial (h, c), and layer_states[s] holds (h, c) after step s.
  const int64_t seq_length = input.size(0);
  std::vector<std::tuple<Tensor, Tensor, Tensor, Tensor>> layer_gates(
      seq_length);
  std::vector<std::tuple<Tensor, Tensor>> layer_states(seq_length + 1);
  layer_states[0] = std::make_tuple(hx_, cx_tmp);
  for (int64_t seq = 1; seq < seq_length + 1; seq++) {
    // The vector is never resized, so these references stay valid.
    const auto& [hx, cx] = layer_states[seq - 1];
    const int64_t x_index = reverse ? seq_length - seq : seq - 1;
    auto gate = at::linear(input[x_index], weight0, bias_ih)
                    .add_(at::linear(hx, weight1, bias_hh));
    auto chunked_gates = gate.unsafe_chunk(4, 1);
    auto i = chunked_gates[0].sigmoid_();
    auto f = chunked_gates[1].sigmoid_();
    auto g = chunked_gates[2].tanh_();
    auto o = chunked_gates[3].sigmoid_();
    layer_gates[x_index] = std::make_tuple(i, f, g, o);
    auto cy = (f * cx).add(i * g);
    auto hy = o * cy.tanh();
    layer_states[seq] = std::make_tuple(hy, cy);
  }

  // Backward recurrence, walking the processing steps in reverse. Weight and
  // bias gradients are accumulated across steps. dprev_h and dprev_c carry
  // the gradient to the previous step, and their final values are the
  // gradients w.r.t. hx_ / cx_tmp.
  Tensor dWx, dWh, db, dprev_h, dprev_c;
  std::vector<at::Tensor> layer_dx(seq_length);
  for (int64_t seq = seq_length - 1; seq >= 0; seq--) {
    const int64_t x_index = reverse ? seq_length - seq - 1 : seq;
    const auto& [i, f, g, o] = layer_gates[x_index];
    const auto& cy = std::get<1>(layer_states[seq + 1]);
    const auto& [hx, cx] = layer_states[seq];
    // tanh(c_t) is used twice below; compute it once.
    const auto tanh_cy = cy.tanh();

    // Total gradient reaching h_t: this step's output plus the next step.
    const auto new_grad_hy = grad_output[x_index].add(grad_hy);
    // Gradient reaching c_t: from the next step, plus through
    // h_t = o * tanh(c_t).
    const auto d1 = grad_cy.add(new_grad_hy * o * (1 - tanh_cy * tanh_cy));

    // Pre-activation gate gradients (sigmoid' = s * (1 - s),
    // tanh' = 1 - t^2).
    const auto di = d1 * g * i * (1 - i);
    const auto df = d1 * cx * f * (1 - f);
    const auto dg = d1 * i * (1 - g * g);
    const auto do_ = new_grad_hy * tanh_cy * o * (1 - o);
    // Same (i, f, g, o) order as the forward unsafe_chunk.
    const auto da = at::cat({di, df, dg, do_}, 1);

    const auto db_step = at::sum(da, 0);
    const auto dWx_step = at::matmul(da.transpose(0, 1), input[x_index]);
    const auto dWh_step = at::matmul(da.transpose(0, 1), hx);
    if (seq == seq_length - 1) {
      db = db_step;
      dWx = dWx_step;
      dWh = dWh_step;
    } else {
      db += db_step;
      dWx += dWx_step;
      dWh += dWh_step;
    }

    layer_dx[x_index] = at::unsqueeze(at::matmul(da, weight0), 0);
    dprev_h = at::matmul(da, weight1);
    dprev_c = d1 * f;
    grad_hy = dprev_h;
    grad_cy = dprev_c;
  }

  auto cat_layer_dx = at::cat(layer_dx, 0);
  return std::make_tuple(cat_layer_dx, dWx, dWh, db, db, dprev_h, dprev_c);
}

// Backward of `values()`: wraps `grad` (the gradient w.r.t. the values) into
// a sparse tensor that has the same indices, sizes and options as `self`.
// Returns an undefined tensor when `grad` is undefined. Throws
// NotImplementedError for layouts that are neither sparse COO nor sparse
// compressed.
Tensor values_backward(const Tensor& grad, const Tensor& self) {
  // No incoming gradient: the gradient w.r.t. self is undefined as well.
  if (!grad.defined()) {
    return Tensor();
  }

  // Sparse COO: reuse self's indices. self is marked coalesced, matching the
  // original indices layout.
  if (self.layout() == c10::kSparse) {
    return at::_sparse_coo_tensor_unsafe_symint(
        self.indices(),
        grad,
        self.sym_sizes(),
        self.options(),
        /*is_coalesced=*/true);
  }

  // Sparse compressed layouts: reuse self's compressed and plain indices.
  if (at::sparse_csr::is_sparse_compressed(self)) {
    auto [compressed_idx, plain_idx] =
        at::sparse_csr::getCompressedPlainIndices(self);
    return at::_sparse_compressed_tensor_unsafe_symint(
        compressed_idx, plain_idx, grad, self.sym_sizes(), self.options());
  }

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "values backward with respect to self with layout ",
      self.layout());
  return Tensor();
}

} // namespace details
} // namespace generated
} // namespace autograd
} // namespace torch
